feat: add index inspect
This commit is contained in:
		| @@ -108,7 +108,7 @@ class Command: | |||||||
|             ret.append(version) |             ret.append(version) | ||||||
|         return ret |         return ret | ||||||
|  |  | ||||||
|     async def inspectdb(self, tables: List[str]) -> str: |     async def inspectdb(self, tables: List[str] = None) -> str: | ||||||
|         connection = get_app_connection(self.tortoise_config, self.app) |         connection = get_app_connection(self.tortoise_config, self.app) | ||||||
|         dialect = connection.schema_generator.DIALECT |         dialect = connection.schema_generator.DIALECT | ||||||
|         if dialect == "mysql": |         if dialect == "mysql": | ||||||
|   | |||||||
| @@ -12,17 +12,22 @@ class Column(BaseModel): | |||||||
|     comment: Optional[str] |     comment: Optional[str] | ||||||
|     pk: bool |     pk: bool | ||||||
|     unique: bool |     unique: bool | ||||||
|  |     index: bool | ||||||
|     length: Optional[int] |     length: Optional[int] | ||||||
|     extra: Optional[str] |     extra: Optional[str] | ||||||
|     decimal_places: Optional[int] |     decimal_places: Optional[int] | ||||||
|     max_digits: Optional[int] |     max_digits: Optional[int] | ||||||
|  |  | ||||||
|     def translate(self) -> dict: |     def translate(self) -> dict: | ||||||
|         comment = default = length = unique = null = pk = "" |         comment = default = length = index = null = pk = "" | ||||||
|         if self.pk: |         if self.pk: | ||||||
|             pk = "pk=True, " |             pk = "pk=True, " | ||||||
|         if self.unique: |         else: | ||||||
|             unique = "unique=True, " |             if self.unique: | ||||||
|  |                 index = "unique=True, " | ||||||
|  |             else: | ||||||
|  |                 if self.index: | ||||||
|  |                     index = "index=True, " | ||||||
|         if self.data_type in ["varchar", "VARCHAR"]: |         if self.data_type in ["varchar", "VARCHAR"]: | ||||||
|             length = f"max_length={self.length}, " |             length = f"max_length={self.length}, " | ||||||
|         if self.data_type == "decimal": |         if self.data_type == "decimal": | ||||||
| @@ -53,7 +58,7 @@ class Column(BaseModel): | |||||||
|         return { |         return { | ||||||
|             "name": self.name, |             "name": self.name, | ||||||
|             "pk": pk, |             "pk": pk, | ||||||
|             "unique": unique, |             "index": index, | ||||||
|             "null": null, |             "null": null, | ||||||
|             "default": default, |             "default": default, | ||||||
|             "length": length, |             "length": length, | ||||||
| @@ -99,7 +104,7 @@ class Inspect: | |||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def decimal_field(cls, **kwargs) -> str: |     def decimal_field(cls, **kwargs) -> str: | ||||||
|         return "{name} = fields.DecimalField({pk}{unique}{length}{null}{default}{comment})".format( |         return "{name} = fields.DecimalField({pk}{index}{length}{null}{default}{comment})".format( | ||||||
|             **kwargs |             **kwargs | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
| @@ -125,21 +130,21 @@ class Inspect: | |||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def char_field(cls, **kwargs) -> str: |     def char_field(cls, **kwargs) -> str: | ||||||
|         return "{name} = fields.CharField({pk}{unique}{length}{null}{default}{comment})".format( |         return "{name} = fields.CharField({pk}{index}{length}{null}{default}{comment})".format( | ||||||
|             **kwargs |             **kwargs | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def int_field(cls, **kwargs) -> str: |     def int_field(cls, **kwargs) -> str: | ||||||
|         return "{name} = fields.IntField({pk}{unique}{comment})".format(**kwargs) |         return "{name} = fields.IntField({pk}{index}{comment})".format(**kwargs) | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def smallint_field(cls, **kwargs) -> str: |     def smallint_field(cls, **kwargs) -> str: | ||||||
|         return "{name} = fields.SmallIntField({pk}{unique}{comment})".format(**kwargs) |         return "{name} = fields.SmallIntField({pk}{index}{comment})".format(**kwargs) | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def bigint_field(cls, **kwargs) -> str: |     def bigint_field(cls, **kwargs) -> str: | ||||||
|         return "{name} = fields.BigIntField({pk}{unique}{default}{comment})".format(**kwargs) |         return "{name} = fields.BigIntField({pk}{index}{default}{comment})".format(**kwargs) | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def bool_field(cls, **kwargs) -> str: |     def bool_field(cls, **kwargs) -> str: | ||||||
| @@ -147,7 +152,7 @@ class Inspect: | |||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def uuid_field(cls, **kwargs) -> str: |     def uuid_field(cls, **kwargs) -> str: | ||||||
|         return "{name} = fields.UUIDField({pk}{unique}{default}{comment})".format(**kwargs) |         return "{name} = fields.UUIDField({pk}{index}{default}{comment})".format(**kwargs) | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def json_field(cls, **kwargs) -> str: |     def json_field(cls, **kwargs) -> str: | ||||||
|   | |||||||
| @@ -30,9 +30,25 @@ class InspectMySQL(Inspect): | |||||||
|  |  | ||||||
|     async def get_columns(self, table: str) -> List[Column]: |     async def get_columns(self, table: str) -> List[Column]: | ||||||
|         columns = [] |         columns = [] | ||||||
|         sql = "select * from information_schema.columns where TABLE_SCHEMA=%s and TABLE_NAME=%s" |         sql = """select c.*, s.NON_UNIQUE, s.INDEX_NAME | ||||||
|  | from information_schema.COLUMNS c | ||||||
|  |          left join information_schema.STATISTICS s on c.TABLE_NAME = s.TABLE_NAME | ||||||
|  |     and c.TABLE_SCHEMA = s.TABLE_SCHEMA | ||||||
|  |     and c.COLUMN_NAME = s.COLUMN_NAME | ||||||
|  | where c.TABLE_SCHEMA = %s | ||||||
|  |   and c.TABLE_NAME = %s""" | ||||||
|         ret = await self.conn.execute_query_dict(sql, [self.database, table]) |         ret = await self.conn.execute_query_dict(sql, [self.database, table]) | ||||||
|         for row in ret: |         for row in ret: | ||||||
|  |             non_unique = row["NON_UNIQUE"] | ||||||
|  |             if non_unique is None: | ||||||
|  |                 unique = False | ||||||
|  |             else: | ||||||
|  |                 unique = not non_unique | ||||||
|  |             index_name = row["INDEX_NAME"] | ||||||
|  |             if index_name is None: | ||||||
|  |                 index = False | ||||||
|  |             else: | ||||||
|  |                 index = row["INDEX_NAME"] != "PRIMARY" | ||||||
|             columns.append( |             columns.append( | ||||||
|                 Column( |                 Column( | ||||||
|                     name=row["COLUMN_NAME"], |                     name=row["COLUMN_NAME"], | ||||||
| @@ -43,6 +59,8 @@ class InspectMySQL(Inspect): | |||||||
|                     comment=row["COLUMN_COMMENT"], |                     comment=row["COLUMN_COMMENT"], | ||||||
|                     unique=row["COLUMN_KEY"] == "UNI", |                     unique=row["COLUMN_KEY"] == "UNI", | ||||||
|                     extra=row["EXTRA"], |                     extra=row["EXTRA"], | ||||||
|  |                     unque=unique, | ||||||
|  |                     index=index, | ||||||
|                     length=row["CHARACTER_MAXIMUM_LENGTH"], |                     length=row["CHARACTER_MAXIMUM_LENGTH"], | ||||||
|                     max_digits=row["NUMERIC_PRECISION"], |                     max_digits=row["NUMERIC_PRECISION"], | ||||||
|                     decimal_places=row["NUMERIC_SCALE"], |                     decimal_places=row["NUMERIC_SCALE"], | ||||||
|   | |||||||
| @@ -54,7 +54,7 @@ from information_schema.constraint_column_usage const | |||||||
|          right join information_schema.columns c using (column_name, table_catalog, table_schema, table_name) |          right join information_schema.columns c using (column_name, table_catalog, table_schema, table_name) | ||||||
| where c.table_catalog = $1 | where c.table_catalog = $1 | ||||||
|   and c.table_name = $2 |   and c.table_name = $2 | ||||||
|   and c.table_schema = $3;""" |   and c.table_schema = $3""" | ||||||
|         ret = await self.conn.execute_query_dict(sql, [self.database, table, self.schema]) |         ret = await self.conn.execute_query_dict(sql, [self.database, table, self.schema]) | ||||||
|         for row in ret: |         for row in ret: | ||||||
|             columns.append( |             columns.append( | ||||||
| @@ -69,6 +69,7 @@ where c.table_catalog = $1 | |||||||
|                     comment=row["column_comment"], |                     comment=row["column_comment"], | ||||||
|                     pk=row["column_key"] == "PRIMARY KEY", |                     pk=row["column_key"] == "PRIMARY KEY", | ||||||
|                     unique=False,  # can't get this simply |                     unique=False,  # can't get this simply | ||||||
|  |                     index=False,  # can't get this simply | ||||||
|                 ) |                 ) | ||||||
|             ) |             ) | ||||||
|         return columns |         return columns | ||||||
|   | |||||||
| @@ -25,6 +25,7 @@ class InspectSQLite(Inspect): | |||||||
|         columns = [] |         columns = [] | ||||||
|         sql = f"PRAGMA table_info({table})" |         sql = f"PRAGMA table_info({table})" | ||||||
|         ret = await self.conn.execute_query_dict(sql) |         ret = await self.conn.execute_query_dict(sql) | ||||||
|  |         columns_index = await self._get_columns_index(table) | ||||||
|         for row in ret: |         for row in ret: | ||||||
|             try: |             try: | ||||||
|                 length = row["type"].split("(")[1].split(")")[0] |                 length = row["type"].split("(")[1].split(")")[0] | ||||||
| @@ -38,11 +39,22 @@ class InspectSQLite(Inspect): | |||||||
|                     default=row["dflt_value"], |                     default=row["dflt_value"], | ||||||
|                     length=length, |                     length=length, | ||||||
|                     pk=row["pk"] == 1, |                     pk=row["pk"] == 1, | ||||||
|                     unique=False,  # can't get this simply |                     unique=columns_index.get(row["name"]) == "unique", | ||||||
|  |                     index=columns_index.get(row["name"]) == "index", | ||||||
|                 ) |                 ) | ||||||
|             ) |             ) | ||||||
|         return columns |         return columns | ||||||
|  |  | ||||||
|  |     async def _get_columns_index(self, table: str): | ||||||
|  |         sql = f"PRAGMA index_list ({table})" | ||||||
|  |         indexes = await self.conn.execute_query_dict(sql) | ||||||
|  |         ret = {} | ||||||
|  |         for index in indexes: | ||||||
|  |             sql = f"PRAGMA index_info({index['name']})" | ||||||
|  |             index_info = (await self.conn.execute_query_dict(sql))[0] | ||||||
|  |             ret[index_info["name"]] = "unique" if index["unique"] else "index" | ||||||
|  |         return ret | ||||||
|  |  | ||||||
|     async def get_all_tables(self) -> List[str]: |     async def get_all_tables(self) -> List[str]: | ||||||
|         sql = "select tbl_name from sqlite_master where type='table' and name!='sqlite_sequence'" |         sql = "select tbl_name from sqlite_master where type='table' and name!='sqlite_sequence'" | ||||||
|         ret = await self.conn.execute_query_dict(sql) |         ret = await self.conn.execute_query_dict(sql) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user