77 lines
		
	
	
		
			3.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			77 lines
		
	
	
		
			3.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
from typing import List, Optional
 | 
						|
 | 
						|
from tortoise import BaseDBAsyncClient
 | 
						|
 | 
						|
from aerich.inspectdb import Column, Inspect
 | 
						|
 | 
						|
 | 
						|
class InspectPostgres(Inspect):
 | 
						|
    def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None):
 | 
						|
        super().__init__(conn, tables)
 | 
						|
        self.schema = self.conn.server_settings.get("schema") or "public"
 | 
						|
 | 
						|
    @property
 | 
						|
    def field_map(self) -> dict:
 | 
						|
        return {
 | 
						|
            "int4": self.int_field,
 | 
						|
            "int8": self.int_field,
 | 
						|
            "smallint": self.smallint_field,
 | 
						|
            "varchar": self.char_field,
 | 
						|
            "text": self.text_field,
 | 
						|
            "bigint": self.bigint_field,
 | 
						|
            "timestamptz": self.datetime_field,
 | 
						|
            "float4": self.float_field,
 | 
						|
            "float8": self.float_field,
 | 
						|
            "date": self.date_field,
 | 
						|
            "time": self.time_field,
 | 
						|
            "decimal": self.decimal_field,
 | 
						|
            "numeric": self.decimal_field,
 | 
						|
            "uuid": self.uuid_field,
 | 
						|
            "jsonb": self.json_field,
 | 
						|
            "bytea": self.binary_field,
 | 
						|
            "bool": self.bool_field,
 | 
						|
            "timestamp": self.datetime_field,
 | 
						|
        }
 | 
						|
 | 
						|
    async def get_all_tables(self) -> List[str]:
 | 
						|
        sql = "select TABLE_NAME from information_schema.TABLES where table_catalog=$1 and table_schema=$2"
 | 
						|
        ret = await self.conn.execute_query_dict(sql, [self.database, self.schema])
 | 
						|
        return list(map(lambda x: x["table_name"], ret))
 | 
						|
 | 
						|
    async def get_columns(self, table: str) -> List[Column]:
 | 
						|
        columns = []
 | 
						|
        sql = f"""select c.column_name,
 | 
						|
       col_description('public.{table}'::regclass, ordinal_position) as column_comment,
 | 
						|
       t.constraint_type as column_key,
 | 
						|
       udt_name as data_type,
 | 
						|
       is_nullable,
 | 
						|
       column_default,
 | 
						|
       character_maximum_length,
 | 
						|
       numeric_precision,
 | 
						|
       numeric_scale
 | 
						|
from information_schema.constraint_column_usage const
 | 
						|
         join information_schema.table_constraints t
 | 
						|
              using (table_catalog, table_schema, table_name, constraint_catalog, constraint_schema, constraint_name)
 | 
						|
         right join information_schema.columns c using (column_name, table_catalog, table_schema, table_name)
 | 
						|
where c.table_catalog = $1
 | 
						|
  and c.table_name = $2
 | 
						|
  and c.table_schema = $3"""
 | 
						|
        ret = await self.conn.execute_query_dict(sql, [self.database, table, self.schema])
 | 
						|
        for row in ret:
 | 
						|
            columns.append(
 | 
						|
                Column(
 | 
						|
                    name=row["column_name"],
 | 
						|
                    data_type=row["data_type"],
 | 
						|
                    null=row["is_nullable"] == "YES",
 | 
						|
                    default=row["column_default"],
 | 
						|
                    length=row["character_maximum_length"],
 | 
						|
                    max_digits=row["numeric_precision"],
 | 
						|
                    decimal_places=row["numeric_scale"],
 | 
						|
                    comment=row["column_comment"],
 | 
						|
                    pk=row["column_key"] == "PRIMARY KEY",
 | 
						|
                    unique=False,  # can't get this simply
 | 
						|
                    index=False,  # can't get this simply
 | 
						|
                )
 | 
						|
            )
 | 
						|
        return columns
 |