diff --git a/aerich/ddl/__init__.py b/aerich/ddl/__init__.py index e4a2402..0bdc345 100644 --- a/aerich/ddl/__init__.py +++ b/aerich/ddl/__init__.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import List, Type +from typing import Any, List, Type, cast from tortoise import BaseDBAsyncClient, Model from tortoise.backends.base.schema_generator import BaseSchemaGenerator @@ -35,25 +35,26 @@ class BaseDDL: ) _RENAME_TABLE_TEMPLATE = 'ALTER TABLE "{old_table_name}" RENAME TO "{new_table_name}"' - def __init__(self, client: "BaseDBAsyncClient"): + def __init__(self, client: "BaseDBAsyncClient") -> None: self.client = client self.schema_generator = self.schema_generator_cls(client) - def create_table(self, model: "Type[Model]"): + def create_table(self, model: "Type[Model]") -> str: return self.schema_generator._get_table_sql(model, True)["table_creation_string"].rstrip( ";" ) - def drop_table(self, table_name: str): + def drop_table(self, table_name: str) -> str: return self._DROP_TABLE_TEMPLATE.format(table_name=table_name) def create_m2m( self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict - ): - through = field_describe.get("through") + ) -> str: + through = cast(str, field_describe.get("through")) description = field_describe.get("description") - reference_id = reference_table_describe.get("pk_field").get("db_column") - db_field_types = reference_table_describe.get("pk_field").get("db_field_types") + pk_field = cast(dict, reference_table_describe.get("pk_field")) + reference_id = pk_field.get("db_column") + db_field_types = cast(dict, pk_field.get("db_field_types")) return self._M2M_TABLE_TEMPLATE.format( table_name=through, backward_table=model._meta.db_table, @@ -73,15 +74,15 @@ class BaseDDL: else "", ) - def drop_m2m(self, table_name: str): + def drop_m2m(self, table_name: str) -> str: return self._DROP_TABLE_TEMPLATE.format(table_name=table_name) - def _get_default(self, model: "Type[Model]", field_describe: dict): + def _get_default(self, model: "Type[Model]", field_describe: dict) -> Any: db_table = model._meta.db_table default = field_describe.get("default") if isinstance(default, Enum): default = default.value - db_column = field_describe.get("db_column") + db_column = cast(str, field_describe.get("db_column")) auto_now_add = field_describe.get("auto_now_add", False) auto_now = field_describe.get("auto_now", False) if default is not None or auto_now_add: @@ -106,25 +107,34 @@ class BaseDDL: default = None return default - def add_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False): + def add_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False) -> str: + return self._add_or_modify_column(model, field_describe, is_pk) + + def _add_or_modify_column(self, model, field_describe: dict, is_pk: bool, modify=False) -> str: db_table = model._meta.db_table description = field_describe.get("description") - db_column = field_describe.get("db_column") - db_field_types = field_describe.get("db_field_types") + db_column = cast(str, field_describe.get("db_column")) + db_field_types = cast(dict, field_describe.get("db_field_types")) default = self._get_default(model, field_describe) if default is None: default = "" - return self._ADD_COLUMN_TEMPLATE.format( + if modify: + unique = "" + template = self._MODIFY_COLUMN_TEMPLATE + else: + unique = "UNIQUE" if field_describe.get("unique") else "" + template = self._ADD_COLUMN_TEMPLATE + return template.format( table_name=db_table, column=self.schema_generator._create_string( db_column=db_column, field_type=db_field_types.get(self.DIALECT, db_field_types.get("")), nullable="NOT NULL" if not field_describe.get("nullable") else "", - unique="UNIQUE" if field_describe.get("unique") else "", + unique=unique, comment=self.schema_generator._column_comment_generator( table=db_table, column=db_column, - comment=field_describe.get("description"), + comment=description, ) if description else "", @@ -133,37 +143,17 @@ class BaseDDL: ), ) - def drop_column(self, model: "Type[Model]", column_name: str): + def drop_column(self, model: "Type[Model]", column_name: str) -> str: return self._DROP_COLUMN_TEMPLATE.format( table_name=model._meta.db_table, column_name=column_name ) - def modify_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False): - db_table = model._meta.db_table - db_field_types = field_describe.get("db_field_types") - default = self._get_default(model, field_describe) - if default is None: - default = "" - return self._MODIFY_COLUMN_TEMPLATE.format( - table_name=db_table, - column=self.schema_generator._create_string( - db_column=field_describe.get("db_column"), - field_type=db_field_types.get(self.DIALECT) or db_field_types.get(""), - nullable="NOT NULL" if not field_describe.get("nullable") else "", - unique="", - comment=self.schema_generator._column_comment_generator( - table=db_table, - column=field_describe.get("db_column"), - comment=field_describe.get("description"), - ) - if field_describe.get("description") - else "", - is_primary_key=is_pk, - default=default, - ), - ) + def modify_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False) -> str: + return self._add_or_modify_column(model, field_describe, is_pk, modify=True) - def rename_column(self, model: "Type[Model]", old_column_name: str, new_column_name: str): + def rename_column( + self, model: "Type[Model]", old_column_name: str, new_column_name: str + ) -> str: return self._RENAME_COLUMN_TEMPLATE.format( table_name=model._meta.db_table, old_column_name=old_column_name, @@ -172,7 +162,7 @@ class BaseDDL: def change_column( self, model: "Type[Model]", old_column_name: str, new_column_name: str, new_column_type: str - ): + ) -> str: return self._CHANGE_COLUMN_TEMPLATE.format( table_name=model._meta.db_table, old_column_name=old_column_name, @@ -180,7 +170,7 @@ class BaseDDL: new_column_type=new_column_type, ) - def add_index(self, model: "Type[Model]", field_names: List[str], unique=False): + def add_index(self, model: "Type[Model]", field_names: List[str], unique=False) -> str: return self._ADD_INDEX_TEMPLATE.format( unique="UNIQUE " if unique else "", index_name=self.schema_generator._generate_index_name( @@ -190,7 +180,7 @@ class BaseDDL: column_names=", ".join(self.schema_generator.quote(f) for f in field_names), ) - def drop_index(self, model: "Type[Model]", field_names: List[str], unique=False): + def drop_index(self, model: "Type[Model]", field_names: List[str], unique=False) -> str: return self._DROP_INDEX_TEMPLATE.format( index_name=self.schema_generator._generate_index_name( "idx" if not unique else "uid", model, field_names @@ -198,45 +188,52 @@ class BaseDDL: table_name=model._meta.db_table, ) - def drop_index_by_name(self, model: "Type[Model]", index_name: str): + def drop_index_by_name(self, model: "Type[Model]", index_name: str) -> str: return self._DROP_INDEX_TEMPLATE.format( index_name=index_name, table_name=model._meta.db_table, ) - def add_fk(self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict): + def _generate_fk_name( + self, db_table, field_describe: dict, reference_table_describe: dict + ) -> str: + """Generate fk name""" + db_column = cast(str, field_describe.get("raw_field")) + pk_field = cast(dict, reference_table_describe.get("pk_field")) + to_field = cast(str, pk_field.get("db_column")) + to_table = cast(str, reference_table_describe.get("table")) + return self.schema_generator._generate_fk_name( + from_table=db_table, + from_field=db_column, + to_table=to_table, + to_field=to_field, + ) + + def add_fk( + self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict + ) -> str: db_table = model._meta.db_table db_column = field_describe.get("raw_field") - reference_id = reference_table_describe.get("pk_field").get("db_column") - fk_name = self.schema_generator._generate_fk_name( - from_table=db_table, - from_field=db_column, - to_table=reference_table_describe.get("table"), - to_field=reference_table_describe.get("pk_field").get("db_column"), - ) + pk_field = cast(dict, reference_table_describe.get("pk_field")) + reference_id = pk_field.get("db_column") return self._ADD_FK_TEMPLATE.format( table_name=db_table, - fk_name=fk_name, + fk_name=self._generate_fk_name(db_table, field_describe, reference_table_describe), db_column=db_column, table=reference_table_describe.get("table"), field=reference_id, on_delete=field_describe.get("on_delete"), ) - def drop_fk(self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict): + def drop_fk( + self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict + ) -> str: db_table = model._meta.db_table - return self._DROP_FK_TEMPLATE.format( - table_name=db_table, - fk_name=self.schema_generator._generate_fk_name( - from_table=db_table, - from_field=field_describe.get("raw_field"), - to_table=reference_table_describe.get("table"), - to_field=reference_table_describe.get("pk_field").get("db_column"), - ), - ) + fk_name = self._generate_fk_name(db_table, field_describe, reference_table_describe) + return self._DROP_FK_TEMPLATE.format(table_name=db_table, fk_name=fk_name) - def alter_column_default(self, model: "Type[Model]", field_describe: dict): + def alter_column_default(self, model: "Type[Model]", field_describe: dict) -> str: db_table = model._meta.db_table default = self._get_default(model, field_describe) return self._ALTER_DEFAULT_TEMPLATE.format( @@ -245,13 +242,13 @@ class BaseDDL: default="SET" + default if default is not None else "DROP DEFAULT", ) - def alter_column_null(self, model: "Type[Model]", field_describe: dict): + def alter_column_null(self, model: "Type[Model]", field_describe: dict) -> str: return self.modify_column(model, field_describe) - def set_comment(self, model: "Type[Model]", field_describe: dict): + def set_comment(self, model: "Type[Model]", field_describe: dict) -> str: return self.modify_column(model, field_describe) - def rename_table(self, model: "Type[Model]", old_table_name: str, new_table_name: str): + def rename_table(self, model: "Type[Model]", old_table_name: str, new_table_name: str) -> str: db_table = model._meta.db_table return self._RENAME_TABLE_TEMPLATE.format( table_name=db_table, old_table_name=old_table_name, new_table_name=new_table_name diff --git a/aerich/ddl/postgres/__init__.py b/aerich/ddl/postgres/__init__.py index d5282cd..ff5d318 100644 --- a/aerich/ddl/postgres/__init__.py +++ b/aerich/ddl/postgres/__init__.py @@ -1,4 +1,4 @@ -from typing import Type +from typing import Type, cast from tortoise import Model from tortoise.backends.asyncpg.schema_generator import AsyncpgSchemaGenerator @@ -18,7 +18,7 @@ class PostgresDDL(BaseDDL): _SET_COMMENT_TEMPLATE = 'COMMENT ON COLUMN "{table_name}"."{column}" IS {comment}' _DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP CONSTRAINT "{fk_name}"' - def alter_column_null(self, model: "Type[Model]", field_describe: dict): + def alter_column_null(self, model: "Type[Model]", field_describe: dict) -> str: db_table = model._meta.db_table return self._ALTER_NULL_TEMPLATE.format( table_name=db_table, @@ -26,9 +26,9 @@ class PostgresDDL(BaseDDL): set_drop="DROP" if field_describe.get("nullable") else "SET", ) - def modify_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False): + def modify_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False) -> str: db_table = model._meta.db_table - db_field_types = field_describe.get("db_field_types") + db_field_types = cast(dict, field_describe.get("db_field_types")) db_column = field_describe.get("db_column") datatype = db_field_types.get(self.DIALECT) or db_field_types.get("") return self._MODIFY_COLUMN_TEMPLATE.format( @@ -38,7 +38,7 @@ class PostgresDDL(BaseDDL): using=f' USING "{db_column}"::{datatype}', ) - def set_comment(self, model: "Type[Model]", field_describe: dict): + def set_comment(self, model: "Type[Model]", field_describe: dict) -> str: db_table = model._meta.db_table return self._SET_COMMENT_TEMPLATE.format( table_name=db_table, diff --git a/aerich/inspectdb/__init__.py b/aerich/inspectdb/__init__.py index 8402761..996b0c1 100644 --- a/aerich/inspectdb/__init__.py +++ b/aerich/inspectdb/__init__.py @@ -79,7 +79,7 @@ class Inspect: def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None): self.conn = conn try: - self.database = conn.database + self.database = conn.database # type:ignore[attr-defined] except AttributeError: pass self.tables = tables diff --git a/aerich/inspectdb/mysql.py b/aerich/inspectdb/mysql.py index 16d74f0..db83c16 100644 --- a/aerich/inspectdb/mysql.py +++ b/aerich/inspectdb/mysql.py @@ -60,7 +60,8 @@ where c.TABLE_SCHEMA = %s comment=row["COLUMN_COMMENT"], unique=row["COLUMN_KEY"] == "UNI", extra=row["EXTRA"], - unque=unique, + # TODO: why `unque`? + unque=unique, # type:ignore index=index, length=row["CHARACTER_MAXIMUM_LENGTH"], max_digits=row["NUMERIC_PRECISION"], diff --git a/aerich/inspectdb/postgres.py b/aerich/inspectdb/postgres.py index 0f22bb1..d4a7761 100644 --- a/aerich/inspectdb/postgres.py +++ b/aerich/inspectdb/postgres.py @@ -1,14 +1,15 @@ -from typing import List, Optional - -from tortoise import BaseDBAsyncClient +from typing import TYPE_CHECKING, List, Optional from aerich.inspectdb import Column, Inspect +if TYPE_CHECKING: + from tortoise.backends.base_postgres.client import BasePostgresClient + class InspectPostgres(Inspect): - def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None): + def __init__(self, conn: "BasePostgresClient", tables: Optional[List[str]] = None) -> None: super().__init__(conn, tables) - self.schema = self.conn.server_settings.get("schema") or "public" + self.schema = conn.server_settings.get("schema") or "public" @property def field_map(self) -> dict: