feat: add index inspect
This commit is contained in:
		| @@ -108,7 +108,7 @@ class Command: | ||||
|             ret.append(version) | ||||
|         return ret | ||||
|  | ||||
|     async def inspectdb(self, tables: List[str]) -> str: | ||||
|     async def inspectdb(self, tables: List[str] = None) -> str: | ||||
|         connection = get_app_connection(self.tortoise_config, self.app) | ||||
|         dialect = connection.schema_generator.DIALECT | ||||
|         if dialect == "mysql": | ||||
|   | ||||
| @@ -12,17 +12,22 @@ class Column(BaseModel): | ||||
|     comment: Optional[str] | ||||
|     pk: bool | ||||
|     unique: bool | ||||
|     index: bool | ||||
|     length: Optional[int] | ||||
|     extra: Optional[str] | ||||
|     decimal_places: Optional[int] | ||||
|     max_digits: Optional[int] | ||||
|  | ||||
|     def translate(self) -> dict: | ||||
|         comment = default = length = unique = null = pk = "" | ||||
|         comment = default = length = index = null = pk = "" | ||||
|         if self.pk: | ||||
|             pk = "pk=True, " | ||||
|         else: | ||||
|             if self.unique: | ||||
|             unique = "unique=True, " | ||||
|                 index = "unique=True, " | ||||
|             else: | ||||
|                 if self.index: | ||||
|                     index = "index=True, " | ||||
|         if self.data_type in ["varchar", "VARCHAR"]: | ||||
|             length = f"max_length={self.length}, " | ||||
|         if self.data_type == "decimal": | ||||
| @@ -53,7 +58,7 @@ class Column(BaseModel): | ||||
|         return { | ||||
|             "name": self.name, | ||||
|             "pk": pk, | ||||
|             "unique": unique, | ||||
|             "index": index, | ||||
|             "null": null, | ||||
|             "default": default, | ||||
|             "length": length, | ||||
| @@ -99,7 +104,7 @@ class Inspect: | ||||
|  | ||||
|     @classmethod | ||||
|     def decimal_field(cls, **kwargs) -> str: | ||||
|         return "{name} = fields.DecimalField({pk}{unique}{length}{null}{default}{comment})".format( | ||||
|         return "{name} = fields.DecimalField({pk}{index}{length}{null}{default}{comment})".format( | ||||
|             **kwargs | ||||
|         ) | ||||
|  | ||||
| @@ -125,21 +130,21 @@ class Inspect: | ||||
|  | ||||
|     @classmethod | ||||
|     def char_field(cls, **kwargs) -> str: | ||||
|         return "{name} = fields.CharField({pk}{unique}{length}{null}{default}{comment})".format( | ||||
|         return "{name} = fields.CharField({pk}{index}{length}{null}{default}{comment})".format( | ||||
|             **kwargs | ||||
|         ) | ||||
|  | ||||
|     @classmethod | ||||
|     def int_field(cls, **kwargs) -> str: | ||||
|         return "{name} = fields.IntField({pk}{unique}{comment})".format(**kwargs) | ||||
|         return "{name} = fields.IntField({pk}{index}{comment})".format(**kwargs) | ||||
|  | ||||
|     @classmethod | ||||
|     def smallint_field(cls, **kwargs) -> str: | ||||
|         return "{name} = fields.SmallIntField({pk}{unique}{comment})".format(**kwargs) | ||||
|         return "{name} = fields.SmallIntField({pk}{index}{comment})".format(**kwargs) | ||||
|  | ||||
|     @classmethod | ||||
|     def bigint_field(cls, **kwargs) -> str: | ||||
|         return "{name} = fields.BigIntField({pk}{unique}{default}{comment})".format(**kwargs) | ||||
|         return "{name} = fields.BigIntField({pk}{index}{default}{comment})".format(**kwargs) | ||||
|  | ||||
|     @classmethod | ||||
|     def bool_field(cls, **kwargs) -> str: | ||||
| @@ -147,7 +152,7 @@ class Inspect: | ||||
|  | ||||
|     @classmethod | ||||
|     def uuid_field(cls, **kwargs) -> str: | ||||
|         return "{name} = fields.UUIDField({pk}{unique}{default}{comment})".format(**kwargs) | ||||
|         return "{name} = fields.UUIDField({pk}{index}{default}{comment})".format(**kwargs) | ||||
|  | ||||
|     @classmethod | ||||
|     def json_field(cls, **kwargs) -> str: | ||||
|   | ||||
| @@ -30,9 +30,25 @@ class InspectMySQL(Inspect): | ||||
|  | ||||
|     async def get_columns(self, table: str) -> List[Column]: | ||||
|         columns = [] | ||||
|         sql = "select * from information_schema.columns where TABLE_SCHEMA=%s and TABLE_NAME=%s" | ||||
|         sql = """select c.*, s.NON_UNIQUE, s.INDEX_NAME | ||||
| from information_schema.COLUMNS c | ||||
|          left join information_schema.STATISTICS s on c.TABLE_NAME = s.TABLE_NAME | ||||
|     and c.TABLE_SCHEMA = s.TABLE_SCHEMA | ||||
|     and c.COLUMN_NAME = s.COLUMN_NAME | ||||
| where c.TABLE_SCHEMA = %s | ||||
|   and c.TABLE_NAME = %s""" | ||||
|         ret = await self.conn.execute_query_dict(sql, [self.database, table]) | ||||
|         for row in ret: | ||||
|             non_unique = row["NON_UNIQUE"] | ||||
|             if non_unique is None: | ||||
|                 unique = False | ||||
|             else: | ||||
|                 unique = not non_unique | ||||
|             index_name = row["INDEX_NAME"] | ||||
|             if index_name is None: | ||||
|                 index = False | ||||
|             else: | ||||
|                 index = row["INDEX_NAME"] != "PRIMARY" | ||||
|             columns.append( | ||||
|                 Column( | ||||
|                     name=row["COLUMN_NAME"], | ||||
| @@ -43,6 +59,8 @@ class InspectMySQL(Inspect): | ||||
|                     comment=row["COLUMN_COMMENT"], | ||||
|                     unique=row["COLUMN_KEY"] == "UNI", | ||||
|                     extra=row["EXTRA"], | ||||
|                     unque=unique, | ||||
|                     index=index, | ||||
|                     length=row["CHARACTER_MAXIMUM_LENGTH"], | ||||
|                     max_digits=row["NUMERIC_PRECISION"], | ||||
|                     decimal_places=row["NUMERIC_SCALE"], | ||||
|   | ||||
| @@ -54,7 +54,7 @@ from information_schema.constraint_column_usage const | ||||
|          right join information_schema.columns c using (column_name, table_catalog, table_schema, table_name) | ||||
| where c.table_catalog = $1 | ||||
|   and c.table_name = $2 | ||||
|   and c.table_schema = $3;""" | ||||
|   and c.table_schema = $3""" | ||||
|         ret = await self.conn.execute_query_dict(sql, [self.database, table, self.schema]) | ||||
|         for row in ret: | ||||
|             columns.append( | ||||
| @@ -69,6 +69,7 @@ where c.table_catalog = $1 | ||||
|                     comment=row["column_comment"], | ||||
|                     pk=row["column_key"] == "PRIMARY KEY", | ||||
|                     unique=False,  # can't get this simply | ||||
|                     index=False,  # can't get this simply | ||||
|                 ) | ||||
|             ) | ||||
|         return columns | ||||
|   | ||||
| @@ -25,6 +25,7 @@ class InspectSQLite(Inspect): | ||||
|         columns = [] | ||||
|         sql = f"PRAGMA table_info({table})" | ||||
|         ret = await self.conn.execute_query_dict(sql) | ||||
|         columns_index = await self._get_columns_index(table) | ||||
|         for row in ret: | ||||
|             try: | ||||
|                 length = row["type"].split("(")[1].split(")")[0] | ||||
| @@ -38,11 +39,22 @@ class InspectSQLite(Inspect): | ||||
|                     default=row["dflt_value"], | ||||
|                     length=length, | ||||
|                     pk=row["pk"] == 1, | ||||
|                     unique=False,  # can't get this simply | ||||
|                     unique=columns_index.get(row["name"]) == "unique", | ||||
|                     index=columns_index.get(row["name"]) == "index", | ||||
|                 ) | ||||
|             ) | ||||
|         return columns | ||||
|  | ||||
|     async def _get_columns_index(self, table: str): | ||||
|         sql = f"PRAGMA index_list ({table})" | ||||
|         indexes = await self.conn.execute_query_dict(sql) | ||||
|         ret = {} | ||||
|         for index in indexes: | ||||
|             sql = f"PRAGMA index_info({index['name']})" | ||||
|             index_info = (await self.conn.execute_query_dict(sql))[0] | ||||
|             ret[index_info["name"]] = "unique" if index["unique"] else "index" | ||||
|         return ret | ||||
|  | ||||
|     async def get_all_tables(self) -> List[str]: | ||||
|         sql = "select tbl_name from sqlite_master where type='table' and name!='sqlite_sequence'" | ||||
|         ret = await self.conn.execute_query_dict(sql) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user