22 Commits

Author SHA1 Message Date
long2ice
0b01fa38d8 feat: add index inspect 2022-04-05 19:38:08 +08:00
long2ice
801dde15be feat: inspectdb support sqlite 2022-04-01 20:30:36 +08:00
long2ice
75480e2041 Merge remote-tracking branch 'origin/dev' into dev 2022-04-01 19:57:03 +08:00
long2ice
45129cef9f feat: improve inspectdb and support postgres 2022-04-01 19:56:48 +08:00
long2ice
3a0dd2355d Merge pull request #230 from ssilaev/dev
Increase max length of app column
2022-02-09 15:01:39 +08:00
Sergey Silaev
0e71bc16ae Increase max length of app column 2022-02-08 22:14:55 +03:00
long2ice
c39462820c upgrade deps 2022-01-17 22:26:13 +08:00
long2ice
f15cbaf9e0 Support migration for specified index. (#203) 2021-12-29 21:36:23 +08:00
long2ice
15131469df upgrade deps 2021-12-22 16:26:13 +08:00
long2ice
c60c1610f0 Fix pyproject.toml not existing error. (#217) 2021-12-12 22:11:51 +08:00
long2ice
63e8d06157 remove aiomysql 2021-12-08 14:43:33 +08:00
long2ice
68ef8ac676 Fix ci 2021-12-08 14:38:16 +08:00
long2ice
8b5cf6faa0 inspectdb support DATE. (#215) 2021-12-08 14:33:27 +08:00
long2ice
fac00d45cc Remove pydantic dependency. (#198) 2021-10-04 23:05:20 +08:00
long2ice
6f7893d376 Fix section name 2021-09-28 15:07:10 +08:00
long2ice
b1521c4cc7 update version 2021-09-27 19:55:38 +08:00
long2ice
24c1f4cb7d Change default config file from aerich.ini to pyproject.toml. (#197) 2021-09-27 11:05:20 +08:00
long2ice
661f241dac Compatible with old version in indexes 2021-08-31 17:53:17 +08:00
long2ice
01787558d6 Fix test 2021-08-31 17:41:13 +08:00
long2ice
699b0321a4 Support indexes change. (#193) 2021-08-31 17:36:25 +08:00
long2ice
4a83021892 Update FUNDING.yml 2021-08-26 20:39:31 +08:00
long2ice
af63221875 Fix no module found error. (#188) (#189) 2021-08-16 11:14:43 +08:00
23 changed files with 1112 additions and 652 deletions

2
.github/FUNDING.yml vendored
View File

@@ -1 +1 @@
custom: ["https://sponsor.long2ice.cn"]
custom: ["https://sponsor.long2ice.io"]

View File

@@ -26,9 +26,9 @@ jobs:
with:
python-version: '3.x'
- name: Install and configure Poetry
uses: snok/install-poetry@v1.1.1
with:
virtualenvs-create: false
run: |
pip install -U pip poetry
poetry config virtualenvs.create false
- name: CI
env:
MYSQL_PASS: root

View File

@@ -12,9 +12,9 @@ jobs:
with:
python-version: '3.x'
- name: Install and configure Poetry
uses: snok/install-poetry@v1.1.1
with:
virtualenvs-create: false
run: |
pip install -U pip poetry
poetry config virtualenvs.create false
- name: Build dists
run: make build
- name: Pypi Publish

View File

@@ -1,7 +1,39 @@
# ChangeLog
## 0.6
### 0.6.3
- Improve `inspectdb` and support `postgres` & `sqlite`.
### 0.6.2
- Support migration for specified index. (#203)
### 0.6.1
- Fix `pyproject.toml` not existing error. (#217)
### 0.6.0
- Change default config file from `aerich.ini` to `pyproject.toml`. (#197)
**Upgrade note:**
1. Run `aerich init -t config.TORTOISE_ORM`.
2. Remove `aerich.ini`.
- Remove `pydantic` dependency. (#198)
- `inspectdb` support `DATE`. (#215)
## 0.5
### 0.5.8
- Support `indexes` change. (#193)
### 0.5.7
- Fix no module found error. (#188) (#189)
### 0.5.6
- Add `Command` class. (#148) (#141) (#123) (#106)

View File

@@ -12,16 +12,15 @@ up:
@poetry update
deps:
@poetry install -E asyncpg -E asyncmy -E aiomysql
@poetry install -E asyncpg -E asyncmy
style: deps
isort -src $(checkfiles)
black $(black_opts) $(checkfiles)
@isort -src $(checkfiles)
@black $(black_opts) $(checkfiles)
check: deps
black --check $(black_opts) $(checkfiles) || (echo "Please run 'make style' to auto-fix style issues" && false)
flake8 $(checkfiles)
bandit -x tests -r $(checkfiles)
@black --check $(black_opts) $(checkfiles) || (echo "Please run 'make style' to auto-fix style issues" && false)
@pflake8 $(checkfiles)
test: deps
$(py_warn) TEST_DB=sqlite://:memory: py.test

View File

@@ -7,7 +7,7 @@
## Introduction
Aerich is a database migrations tool for Tortoise-ORM, which is like alembic for SQLAlchemy, or like Django ORM with
Aerich is a database migrations tool for TortoiseORM, which is like alembic for SQLAlchemy, or like Django ORM with
it\'s own migration solution.
## Install
@@ -15,7 +15,7 @@ it\'s own migration solution.
Just install from pypi:
```shell
> pip install aerich
pip install aerich
```
## Quick Start
@@ -27,11 +27,8 @@ Usage: aerich [OPTIONS] COMMAND [ARGS]...
Options:
-V, --version Show the version and exit.
-c, --config TEXT Config file. [default: aerich.ini]
-c, --config TEXT Config file. [default: pyproject.toml]
--app TEXT Tortoise-ORM app name.
-n, --name TEXT Name of section in .ini file to use for aerich config.
[default: aerich]
-h, --help Show this message and exit.
Commands:
@@ -70,10 +67,9 @@ Usage: aerich init [OPTIONS]
Init config file and generate root migrate location.
OOptions:
Options:
-t, --tortoise-orm TEXT Tortoise-ORM config module dict variable, like
settings.TORTOISE_ORM. [required]
--location TEXT Migrate store location. [default: ./migrations]
-s, --src_folder TEXT Folder of the source, relative to the project root.
-h, --help Show this message and exit.
@@ -85,7 +81,7 @@ Initialize the config file and migrations location:
> aerich init -t tests.backends.mysql.TORTOISE_ORM
Success create migrate location ./migrations
Success generate config file aerich.ini
Success write config to pyproject.toml
```
### Init db
@@ -169,7 +165,7 @@ Now your db is rolled back to the specified version.
### Inspect db tables to TortoiseORM model
Currently `inspectdb` only supports MySQL.
Currently `inspectdb` support MySQL & Postgres & SQLite.
```shell
Usage: aerich inspectdb [OPTIONS]
@@ -193,7 +189,44 @@ Inspect a specified table in the default app and redirect to `models.py`:
aerich inspectdb -t user > models.py
```
Note that this command is limited and cannot infer some fields, such as `IntEnumField`, `ForeignKeyField`, and others.
For example, you table is:
```sql
CREATE TABLE `test`
(
`id` int NOT NULL AUTO_INCREMENT,
`decimal` decimal(10, 2) NOT NULL,
`date` date DEFAULT NULL,
`datetime` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
`time` time DEFAULT NULL,
`float` float DEFAULT NULL,
`string` varchar(200) COLLATE utf8mb4_general_ci DEFAULT NULL,
`tinyint` tinyint DEFAULT NULL,
PRIMARY KEY (`id`),
KEY `asyncmy_string_index` (`string`)
) ENGINE = InnoDB
DEFAULT CHARSET = utf8mb4
COLLATE = utf8mb4_general_ci
```
Now run `aerich inspectdb -t test` to see the generated model:
```python
from tortoise import Model, fields
class Test(Model):
date = fields.DateField(null=True, )
datetime = fields.DatetimeField(auto_now=True, )
decimal = fields.DecimalField(max_digits=10, decimal_places=2, )
float = fields.FloatField(null=True, )
id = fields.IntField(pk=True, )
string = fields.CharField(max_length=200, null=True, )
time = fields.TimeField(null=True, )
tinyint = fields.BooleanField(null=True, )
```
Note that this command is limited and can't infer some fields, such as `IntEnumField`, `ForeignKeyField`, and others.
### Multiple databases

View File

@@ -1,5 +1,3 @@
__version__ = "0.5.6"
import os
from pathlib import Path
from typing import List
@@ -10,11 +8,12 @@ from tortoise.transactions import in_transaction
from tortoise.utils import get_schema_sql
from aerich.exceptions import DowngradeError
from aerich.inspectdb import InspectDb
from aerich.inspect.mysql import InspectMySQL
from aerich.inspect.postgres import InspectPostgres
from aerich.inspect.sqlite import InspectSQLite
from aerich.migrate import Migrate
from aerich.models import Aerich
from aerich.utils import (
add_src_path,
get_app_connection,
get_app_connection_name,
get_models_describe,
@@ -29,14 +28,11 @@ class Command:
tortoise_config: dict,
app: str = "models",
location: str = "./migrations",
src_folder: str = ".",
):
self.tortoise_config = tortoise_config
self.app = app
self.location = location
self.src_folder = src_folder
Migrate.app = app
add_src_path(src_folder)
async def init(self):
await Migrate.init(self.tortoise_config, self.app, self.location)
@@ -112,10 +108,19 @@ class Command:
ret.append(version)
return ret
async def inspectdb(self, tables: List[str]):
async def inspectdb(self, tables: List[str] = None) -> str:
connection = get_app_connection(self.tortoise_config, self.app)
inspect = InspectDb(connection, tables)
await inspect.inspect()
dialect = connection.schema_generator.DIALECT
if dialect == "mysql":
cls = InspectMySQL
elif dialect == "postgres":
cls = InspectPostgres
elif dialect == "sqlite":
cls = InspectSQLite
else:
raise NotImplementedError(f"{dialect} is not supported")
inspect = cls(connection, tables)
return await inspect.inspect()
async def migrate(self, name: str = "update"):
return await Migrate.migrate(name)

View File

@@ -1,21 +1,20 @@
import asyncio
import os
from configparser import ConfigParser
from functools import wraps
from pathlib import Path
from typing import List
import click
import tomlkit
from click import Context, UsageError
from tomlkit.exceptions import NonExistentKey
from tortoise import Tortoise
from aerich import Command
from aerich.enums import Color
from aerich.exceptions import DowngradeError
from aerich.utils import add_src_path, get_tortoise_config
from . import Command, __version__
from .enums import Color
parser = ConfigParser()
from aerich.version import __version__
CONFIG_DEFAULT_VALUES = {
"src_folder": ".",
@@ -31,7 +30,7 @@ def coro(f):
try:
loop.run_until_complete(f(*args, **kwargs))
finally:
if f.__name__ != "cli":
if Tortoise._inited:
loop.run_until_complete(Tortoise.close_connections())
return wrapper
@@ -42,39 +41,35 @@ def coro(f):
@click.option(
"-c",
"--config",
default="aerich.ini",
default="pyproject.toml",
show_default=True,
help="Config file.",
)
@click.option("--app", required=False, help="Tortoise-ORM app name.")
@click.option(
"-n",
"--name",
default="aerich",
show_default=True,
help="Name of section in .ini file to use for aerich config.",
)
@click.pass_context
@coro
async def cli(ctx: Context, config, app, name):
async def cli(ctx: Context, config, app):
ctx.ensure_object(dict)
ctx.obj["config_file"] = config
ctx.obj["name"] = name
invoked_subcommand = ctx.invoked_subcommand
if invoked_subcommand != "init":
if not Path(config).exists():
raise UsageError("You must exec init first", ctx=ctx)
parser.read(config)
location = parser[name]["location"]
tortoise_orm = parser[name]["tortoise_orm"]
src_folder = parser[name].get("src_folder", CONFIG_DEFAULT_VALUES["src_folder"])
with open(config, "r") as f:
content = f.read()
doc = tomlkit.parse(content)
try:
tool = doc["tool"]["aerich"]
location = tool["location"]
tortoise_orm = tool["tortoise_orm"]
src_folder = tool.get("src_folder", CONFIG_DEFAULT_VALUES["src_folder"])
except NonExistentKey:
raise UsageError("You need run aerich init again when upgrade to 0.6.0+")
add_src_path(src_folder)
tortoise_config = get_tortoise_config(ctx, tortoise_orm)
app = app or list(tortoise_config.get("apps").keys())[0]
command = Command(
tortoise_config=tortoise_config, app=app, location=location, src_folder=src_folder
)
command = Command(tortoise_config=tortoise_config, app=app, location=location)
ctx.obj["command"] = command
if invoked_subcommand != "init-db":
if not Path(location, app).exists():
@@ -187,9 +182,6 @@ async def history(ctx: Context):
@coro
async def init(ctx: Context, tortoise_orm, location, src_folder):
config_file = ctx.obj["config_file"]
name = ctx.obj["name"]
if Path(config_file).exists():
return click.secho("Configuration file already created", fg=Color.yellow)
if os.path.isabs(src_folder):
src_folder = os.path.relpath(os.getcwd(), src_folder)
@@ -200,19 +192,25 @@ async def init(ctx: Context, tortoise_orm, location, src_folder):
# check that we can find the configuration, if not we can fail before the config file gets created
add_src_path(src_folder)
get_tortoise_config(ctx, tortoise_orm)
if Path(config_file).exists():
with open(config_file, "r") as f:
content = f.read()
doc = tomlkit.parse(content)
else:
doc = tomlkit.parse("[tool.aerich]")
table = tomlkit.table()
table["tortoise_orm"] = tortoise_orm
table["location"] = location
table["src_folder"] = src_folder
doc["tool"]["aerich"] = table
parser.add_section(name)
parser.set(name, "tortoise_orm", tortoise_orm)
parser.set(name, "location", location)
parser.set(name, "src_folder", src_folder)
with open(config_file, "w", encoding="utf-8") as f:
parser.write(f)
with open(config_file, "w") as f:
f.write(tomlkit.dumps(doc))
Path(location).mkdir(parents=True, exist_ok=True)
click.secho(f"Success create migrate location {location}", fg=Color.green)
click.secho(f"Success generate config file {config_file}", fg=Color.green)
click.secho(f"Success write config to {config_file}", fg=Color.green)
@cli.command(help="Generate schema and generate app migrate location.")
@@ -251,7 +249,8 @@ async def init_db(ctx: Context, safe):
@coro
async def inspectdb(ctx: Context, table: List[str]):
command = ctx.obj["command"]
await command.inspectdb(table)
ret = await command.inspectdb(table)
click.secho(ret)
def main():

31
aerich/coder.py Normal file
View File

@@ -0,0 +1,31 @@
import base64
import json
import pickle # nosec: B301,B403
from tortoise.indexes import Index
class JsonEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, Index):
return {
"type": "index",
"val": base64.b64encode(pickle.dumps(obj)).decode(), # nosec: B301
}
else:
return super().default(obj)
def object_hook(obj):
_type = obj.get("type")
if not _type:
return obj
return pickle.loads(base64.b64decode(obj["val"])) # nosec: B301
def encoder(obj: dict):
return json.dumps(obj, cls=JsonEncoder)
def decoder(obj: str):
return json.loads(obj, object_hook=object_hook)

View File

@@ -78,15 +78,11 @@ class BaseDDL:
auto_now_add = field_describe.get("auto_now_add", False)
auto_now = field_describe.get("auto_now", False)
if default is not None or auto_now_add:
if (
field_describe.get("field_type")
in [
"UUIDField",
"TextField",
"JSONField",
]
or is_default_function(default)
):
if field_describe.get("field_type") in [
"UUIDField",
"TextField",
"JSONField",
] or is_default_function(default):
default = ""
else:
try:
@@ -195,6 +191,12 @@ class BaseDDL:
table_name=model._meta.db_table,
)
def drop_index_by_name(self, model: "Type[Model]", index_name: str):
return self._DROP_INDEX_TEMPLATE.format(
index_name=index_name,
table_name=model._meta.db_table,
)
def add_fk(self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict):
db_table = model._meta.db_table

163
aerich/inspect/__init__.py Normal file
View File

@@ -0,0 +1,163 @@
from typing import Any, List, Optional
from pydantic import BaseModel
from tortoise import BaseDBAsyncClient
class Column(BaseModel):
name: str
data_type: str
null: bool
default: Any
comment: Optional[str]
pk: bool
unique: bool
index: bool
length: Optional[int]
extra: Optional[str]
decimal_places: Optional[int]
max_digits: Optional[int]
def translate(self) -> dict:
comment = default = length = index = null = pk = ""
if self.pk:
pk = "pk=True, "
else:
if self.unique:
index = "unique=True, "
else:
if self.index:
index = "index=True, "
if self.data_type in ["varchar", "VARCHAR"]:
length = f"max_length={self.length}, "
if self.data_type == "decimal":
length = f"max_digits={self.max_digits}, decimal_places={self.decimal_places}, "
if self.null:
null = "null=True, "
if self.default is not None:
if self.data_type in ["tinyint", "INT"]:
default = f"default={'True' if self.default == '1' else 'False'}, "
elif self.data_type == "bool":
default = f"default={'True' if self.default == 'true' else 'False'}, "
elif self.data_type in ["datetime", "timestamptz", "TIMESTAMP"]:
if "CURRENT_TIMESTAMP" == self.default:
if "DEFAULT_GENERATED on update CURRENT_TIMESTAMP" == self.extra:
default = "auto_now=True, "
else:
default = "auto_now_add=True, "
else:
if "::" in self.default:
default = f"default={self.default.split('::')[0]}, "
elif self.default.endswith("()"):
default = ""
else:
default = f"default={self.default}, "
if self.comment:
comment = f"description='{self.comment}', "
return {
"name": self.name,
"pk": pk,
"index": index,
"null": null,
"default": default,
"length": length,
"comment": comment,
}
class Inspect:
_table_template = "class {table}(Model):\n"
def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None):
self.conn = conn
try:
self.database = conn.database
except AttributeError:
pass
self.tables = tables
@property
def field_map(self) -> dict:
raise NotImplementedError
async def inspect(self) -> str:
if not self.tables:
self.tables = await self.get_all_tables()
result = "from tortoise import Model, fields\n\n\n"
tables = []
for table in self.tables:
columns = await self.get_columns(table)
fields = []
model = self._table_template.format(table=table.title().replace("_", ""))
for column in columns:
field = self.field_map[column.data_type](**column.translate())
fields.append(" " + field)
tables.append(model + "\n".join(fields))
return result + "\n\n\n".join(tables)
async def get_columns(self, table: str) -> List[Column]:
raise NotImplementedError
async def get_all_tables(self) -> List[str]:
raise NotImplementedError
@classmethod
def decimal_field(cls, **kwargs) -> str:
return "{name} = fields.DecimalField({pk}{index}{length}{null}{default}{comment})".format(
**kwargs
)
@classmethod
def time_field(cls, **kwargs) -> str:
return "{name} = fields.TimeField({null}{default}{comment})".format(**kwargs)
@classmethod
def date_field(cls, **kwargs) -> str:
return "{name} = fields.DateField({null}{default}{comment})".format(**kwargs)
@classmethod
def float_field(cls, **kwargs) -> str:
return "{name} = fields.FloatField({null}{default}{comment})".format(**kwargs)
@classmethod
def datetime_field(cls, **kwargs) -> str:
return "{name} = fields.DatetimeField({null}{default}{comment})".format(**kwargs)
@classmethod
def text_field(cls, **kwargs) -> str:
return "{name} = fields.TextField({null}{default}{comment})".format(**kwargs)
@classmethod
def char_field(cls, **kwargs) -> str:
return "{name} = fields.CharField({pk}{index}{length}{null}{default}{comment})".format(
**kwargs
)
@classmethod
def int_field(cls, **kwargs) -> str:
return "{name} = fields.IntField({pk}{index}{comment})".format(**kwargs)
@classmethod
def smallint_field(cls, **kwargs) -> str:
return "{name} = fields.SmallIntField({pk}{index}{comment})".format(**kwargs)
@classmethod
def bigint_field(cls, **kwargs) -> str:
return "{name} = fields.BigIntField({pk}{index}{default}{comment})".format(**kwargs)
@classmethod
def bool_field(cls, **kwargs) -> str:
return "{name} = fields.BooleanField({null}{default}{comment})".format(**kwargs)
@classmethod
def uuid_field(cls, **kwargs) -> str:
return "{name} = fields.UUIDField({pk}{index}{default}{comment})".format(**kwargs)
@classmethod
def json_field(cls, **kwargs) -> str:
return "{name} = fields.JSONField({null}{default}{comment})".format(**kwargs)
@classmethod
def binary_field(cls, **kwargs) -> str:
return "{name} = fields.BinaryField({null}{default}{comment})".format(**kwargs)

69
aerich/inspect/mysql.py Normal file
View File

@@ -0,0 +1,69 @@
from typing import List
from aerich.inspect import Column, Inspect
class InspectMySQL(Inspect):
@property
def field_map(self) -> dict:
return {
"int": self.int_field,
"smallint": self.smallint_field,
"tinyint": self.bool_field,
"bigint": self.bigint_field,
"varchar": self.char_field,
"longtext": self.text_field,
"text": self.text_field,
"datetime": self.datetime_field,
"float": self.float_field,
"date": self.date_field,
"time": self.time_field,
"decimal": self.decimal_field,
"json": self.json_field,
"longblob": self.binary_field,
}
async def get_all_tables(self) -> List[str]:
sql = "select TABLE_NAME from information_schema.TABLES where TABLE_SCHEMA=%s"
ret = await self.conn.execute_query_dict(sql, [self.database])
return list(map(lambda x: x["TABLE_NAME"], ret))
async def get_columns(self, table: str) -> List[Column]:
columns = []
sql = """select c.*, s.NON_UNIQUE, s.INDEX_NAME
from information_schema.COLUMNS c
left join information_schema.STATISTICS s on c.TABLE_NAME = s.TABLE_NAME
and c.TABLE_SCHEMA = s.TABLE_SCHEMA
and c.COLUMN_NAME = s.COLUMN_NAME
where c.TABLE_SCHEMA = %s
and c.TABLE_NAME = %s"""
ret = await self.conn.execute_query_dict(sql, [self.database, table])
for row in ret:
non_unique = row["NON_UNIQUE"]
if non_unique is None:
unique = False
else:
unique = not non_unique
index_name = row["INDEX_NAME"]
if index_name is None:
index = False
else:
index = row["INDEX_NAME"] != "PRIMARY"
columns.append(
Column(
name=row["COLUMN_NAME"],
data_type=row["DATA_TYPE"],
null=row["IS_NULLABLE"] == "YES",
default=row["COLUMN_DEFAULT"],
pk=row["COLUMN_KEY"] == "PRI",
comment=row["COLUMN_COMMENT"],
unique=row["COLUMN_KEY"] == "UNI",
extra=row["EXTRA"],
unque=unique,
index=index,
length=row["CHARACTER_MAXIMUM_LENGTH"],
max_digits=row["NUMERIC_PRECISION"],
decimal_places=row["NUMERIC_SCALE"],
)
)
return columns

View File

@@ -0,0 +1,75 @@
from typing import List, Optional
from tortoise import BaseDBAsyncClient
from aerich.inspect import Column, Inspect
class InspectPostgres(Inspect):
def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None):
super().__init__(conn, tables)
self.schema = self.conn.server_settings.get("schema") or "public"
@property
def field_map(self) -> dict:
return {
"int4": self.int_field,
"int8": self.int_field,
"smallint": self.smallint_field,
"varchar": self.char_field,
"text": self.text_field,
"bigint": self.bigint_field,
"timestamptz": self.datetime_field,
"float4": self.float_field,
"float8": self.float_field,
"date": self.date_field,
"time": self.time_field,
"decimal": self.decimal_field,
"uuid": self.uuid_field,
"jsonb": self.json_field,
"bytea": self.binary_field,
"bool": self.bool_field,
"timestamp": self.datetime_field,
}
async def get_all_tables(self) -> List[str]:
sql = "select TABLE_NAME from information_schema.TABLES where table_catalog=$1 and table_schema=$2"
ret = await self.conn.execute_query_dict(sql, [self.database, self.schema])
return list(map(lambda x: x["table_name"], ret))
async def get_columns(self, table: str) -> List[Column]:
columns = []
sql = f"""select c.column_name,
col_description('public.{table}'::regclass, ordinal_position) as column_comment,
t.constraint_type as column_key,
udt_name as data_type,
is_nullable,
column_default,
character_maximum_length,
numeric_precision,
numeric_scale
from information_schema.constraint_column_usage const
join information_schema.table_constraints t
using (table_catalog, table_schema, table_name, constraint_catalog, constraint_schema, constraint_name)
right join information_schema.columns c using (column_name, table_catalog, table_schema, table_name)
where c.table_catalog = $1
and c.table_name = $2
and c.table_schema = $3"""
ret = await self.conn.execute_query_dict(sql, [self.database, table, self.schema])
for row in ret:
columns.append(
Column(
name=row["column_name"],
data_type=row["data_type"],
null=row["is_nullable"] == "YES",
default=row["column_default"],
length=row["character_maximum_length"],
max_digits=row["numeric_precision"],
decimal_places=row["numeric_scale"],
comment=row["column_comment"],
pk=row["column_key"] == "PRIMARY KEY",
unique=False, # can't get this simply
index=False, # can't get this simply
)
)
return columns

61
aerich/inspect/sqlite.py Normal file
View File

@@ -0,0 +1,61 @@
from typing import List
from aerich.inspect import Column, Inspect
class InspectSQLite(Inspect):
@property
def field_map(self) -> dict:
return {
"INTEGER": self.int_field,
"INT": self.bool_field,
"SMALLINT": self.smallint_field,
"VARCHAR": self.char_field,
"TEXT": self.text_field,
"TIMESTAMP": self.datetime_field,
"REAL": self.float_field,
"BIGINT": self.bigint_field,
"DATE": self.date_field,
"TIME": self.time_field,
"JSON": self.json_field,
"BLOB": self.binary_field,
}
async def get_columns(self, table: str) -> List[Column]:
columns = []
sql = f"PRAGMA table_info({table})"
ret = await self.conn.execute_query_dict(sql)
columns_index = await self._get_columns_index(table)
for row in ret:
try:
length = row["type"].split("(")[1].split(")")[0]
except IndexError:
length = None
columns.append(
Column(
name=row["name"],
data_type=row["type"].split("(")[0],
null=row["notnull"] == 0,
default=row["dflt_value"],
length=length,
pk=row["pk"] == 1,
unique=columns_index.get(row["name"]) == "unique",
index=columns_index.get(row["name"]) == "index",
)
)
return columns
async def _get_columns_index(self, table: str):
sql = f"PRAGMA index_list ({table})"
indexes = await self.conn.execute_query_dict(sql)
ret = {}
for index in indexes:
sql = f"PRAGMA index_info({index['name']})"
index_info = (await self.conn.execute_query_dict(sql))[0]
ret[index_info["name"]] = "unique" if index["unique"] else "index"
return ret
async def get_all_tables(self) -> List[str]:
sql = "select tbl_name from sqlite_master where type='table' and name!='sqlite_sequence'"
ret = await self.conn.execute_query_dict(sql)
return list(map(lambda x: x["tbl_name"], ret))

View File

@@ -1,86 +0,0 @@
import sys
from typing import List, Optional
from ddlparse import DdlParse
from tortoise import BaseDBAsyncClient
class InspectDb:
_table_template = "class {table}(Model):\n"
_field_template_mapping = {
"INT": " {field} = fields.IntField({pk}{unique}{comment})",
"SMALLINT": " {field} = fields.IntField({pk}{unique}{comment})",
"TINYINT": " {field} = fields.BooleanField({null}{default}{comment})",
"VARCHAR": " {field} = fields.CharField({pk}{unique}{length}{null}{default}{comment})",
"LONGTEXT": " {field} = fields.TextField({null}{default}{comment})",
"TEXT": " {field} = fields.TextField({null}{default}{comment})",
"DATETIME": " {field} = fields.DatetimeField({null}{default}{comment})",
"FLOAT": " {field} = fields.FloatField({null}{default}{comment})",
}
def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None):
self.conn = conn
self.tables = tables
self.DIALECT = conn.schema_generator.DIALECT
async def show_create_tables(self):
if self.DIALECT == "mysql":
if not self.tables:
sql_tables = f"SELECT table_name FROM information_schema.tables WHERE table_schema = '{self.conn.database}';" # nosec: B608
ret = await self.conn.execute_query(sql_tables)
self.tables = map(lambda x: x["TABLE_NAME"], ret[1])
for table in self.tables:
sql_show_create_table = f"SHOW CREATE TABLE {table}"
ret = await self.conn.execute_query(sql_show_create_table)
yield ret[1][0]["Create Table"]
else:
raise NotImplementedError("Currently only support MySQL")
async def inspect(self):
ddl_list = self.show_create_tables()
result = "from tortoise import Model, fields\n\n\n"
tables = []
async for ddl in ddl_list:
parser = DdlParse(ddl, DdlParse.DATABASE.mysql)
table = parser.parse()
name = table.name.title()
columns = table.columns
fields = []
model = self._table_template.format(table=name)
for column_name, column in columns.items():
comment = default = length = unique = null = pk = ""
if column.primary_key:
pk = "pk=True, "
if column.unique:
unique = "unique=True, "
if column.data_type == "VARCHAR":
length = f"max_length={column.length}, "
if not column.not_null:
null = "null=True, "
if column.default is not None:
if column.data_type == "TINYINT":
default = f"default={'True' if column.default == '1' else 'False'}, "
elif column.data_type == "DATETIME":
if "CURRENT_TIMESTAMP" in column.default:
if "ON UPDATE CURRENT_TIMESTAMP" in ddl:
default = "auto_now_add=True, "
else:
default = "auto_now=True, "
else:
default = f"default={column.default}, "
if column.comment:
comment = f"description='{column.comment}', "
field = self._field_template_mapping[column.data_type].format(
field=column_name,
pk=pk,
unique=unique,
length=length,
null=null,
default=default,
comment=comment,
)
fields.append(field)
tables.append(model + "\n".join(fields))
sys.stdout.write(result + "\n\n\n".join(tables))

View File

@@ -1,12 +1,14 @@
import os
from datetime import datetime
from hashlib import md5
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Type
from typing import Dict, List, Optional, Tuple, Type, Union
import click
from dictdiffer import diff
from tortoise import BaseDBAsyncClient, Model, Tortoise
from tortoise.exceptions import OperationalError
from tortoise.indexes import Index
from aerich.ddl import BaseDDL
from aerich.models import MAX_VERSION_LENGTH, Aerich
@@ -32,7 +34,7 @@ class Migrate:
ddl: BaseDDL
_last_version_content: Optional[dict] = None
app: str
migrate_location: str
migrate_location: Path
dialect: str
_db_version: Optional[str] = None
@@ -138,25 +140,37 @@ class Migrate:
return await cls._generate_diff_sql(name)
@classmethod
def _add_operator(cls, operator: str, upgrade=True, fk_m2m=False):
def _add_operator(cls, operator: str, upgrade=True, fk_m2m_index=False):
"""
add operator,differentiate fk because fk is order limit
:param operator:
:param upgrade:
:param fk_m2m:
:param fk_m2m_index:
:return:
"""
if upgrade:
if fk_m2m:
if fk_m2m_index:
cls._upgrade_fk_m2m_index_operators.append(operator)
else:
cls.upgrade_operators.append(operator)
else:
if fk_m2m:
if fk_m2m_index:
cls._downgrade_fk_m2m_index_operators.append(operator)
else:
cls.downgrade_operators.append(operator)
@classmethod
def _handle_indexes(cls, model: Type[Model], indexes: List[Union[Tuple[str], Index]]):
ret = []
for index in indexes:
if isinstance(index, Index):
index.__hash__ = lambda self: md5( # nosec: B303
self.index_name(cls.ddl.schema_generator, model).encode()
+ self.__class__.__name__.encode()
).hexdigest()
ret.append(index)
return ret
@classmethod
def diff_models(cls, old_models: Dict[str, dict], new_models: Dict[str, dict], upgrade=True):
"""
@@ -192,7 +206,18 @@ class Migrate:
new_unique_together = set(
map(lambda x: tuple(x), new_model_describe.get("unique_together"))
)
old_indexes = set(
map(
lambda x: x if isinstance(x, Index) else tuple(x),
cls._handle_indexes(model, old_model_describe.get("indexes", [])),
)
)
new_indexes = set(
map(
lambda x: x if isinstance(x, Index) else tuple(x),
cls._handle_indexes(model, new_model_describe.get("indexes", [])),
)
)
old_pk_field = old_model_describe.get("pk_field")
new_pk_field = new_model_describe.get("pk_field")
# pk field
@@ -224,7 +249,7 @@ class Migrate:
new_models.get(change[0][1].get("model_name")),
),
upgrade,
fk_m2m=True,
fk_m2m_index=True,
)
elif action == "remove":
add = False
@@ -235,14 +260,19 @@ class Migrate:
cls._downgrade_m2m.append(table)
add = True
if add:
cls._add_operator(cls.drop_m2m(table), upgrade, fk_m2m=True)
cls._add_operator(cls.drop_m2m(table), upgrade, True)
# add unique_together
for index in new_unique_together.difference(old_unique_together):
cls._add_operator(cls._add_index(model, index, True), upgrade, True)
# remove unique_together
for index in old_unique_together.difference(new_unique_together):
cls._add_operator(cls._drop_index(model, index, True), upgrade, True)
# add indexes
for index in new_indexes.difference(old_indexes):
cls._add_operator(cls._add_index(model, index, False), upgrade, True)
# remove indexes
for index in old_indexes.difference(new_indexes):
cls._add_operator(cls._drop_index(model, index, False), upgrade, True)
old_data_fields = old_model_describe.get("data_fields")
new_data_fields = new_model_describe.get("data_fields")
@@ -356,7 +386,7 @@ class Migrate:
model, fk_field, new_models.get(fk_field.get("python_type"))
),
upgrade,
fk_m2m=True,
fk_m2m_index=True,
)
# drop fk
for old_fk_field_name in set(old_fk_fields_name).difference(
@@ -371,7 +401,7 @@ class Migrate:
model, old_fk_field, old_models.get(old_fk_field.get("python_type"))
),
upgrade,
fk_m2m=True,
fk_m2m_index=True,
)
# change fields
for field_name in set(new_data_fields_name).intersection(set(old_data_fields_name)):
@@ -457,12 +487,18 @@ class Migrate:
return ret
@classmethod
def _drop_index(cls, model: Type[Model], fields_name: Tuple[str], unique=False):
def _drop_index(cls, model: Type[Model], fields_name: Union[Tuple[str], Index], unique=False):
if isinstance(fields_name, Index):
return cls.ddl.drop_index_by_name(
model, fields_name.index_name(cls.ddl.schema_generator, model)
)
fields_name = cls._resolve_fk_fields_name(model, fields_name)
return cls.ddl.drop_index(model, fields_name, unique)
@classmethod
def _add_index(cls, model: Type[Model], fields_name: Tuple[str], unique=False):
def _add_index(cls, model: Type[Model], fields_name: Union[Tuple[str], Index], unique=False):
if isinstance(fields_name, Index):
return fields_name.get_sql(cls.ddl.schema_generator, model, False)
fields_name = cls._resolve_fk_fields_name(model, fields_name)
return cls.ddl.add_index(model, fields_name, unique)

View File

@@ -1,12 +1,15 @@
from tortoise import Model, fields
from aerich.coder import decoder, encoder
MAX_VERSION_LENGTH = 255
MAX_APP_LENGTH = 100
class Aerich(Model):
version = fields.CharField(max_length=MAX_VERSION_LENGTH)
app = fields.CharField(max_length=20)
content = fields.JSONField()
app = fields.CharField(max_length=MAX_APP_LENGTH)
content = fields.JSONField(encoder=encoder, decoder=decoder)
class Meta:
ordering = ["-id"]

1
aerich/version.py Normal file
View File

@@ -0,0 +1 @@
__version__ = "0.6.3"

935
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "aerich"
version = "0.5.6"
version = "0.6.3"
description = "A database migrations tool for Tortoise ORM."
authors = ["long2ice <long2ice@gmail.com>"]
license = "Apache-2.0"
@@ -18,11 +18,11 @@ include = ["CHANGELOG.md", "LICENSE", "README.md"]
python = "^3.7"
tortoise-orm = "*"
click = "*"
pydantic = "*"
aiomysql = { version = "*", optional = true }
asyncpg = { version = "*", optional = true }
ddlparse = "*"
asyncmy = { version = "*", optional = true }
pydantic = "*"
dictdiffer = "*"
tomlkit = "*"
[tool.poetry.dev-dependencies]
flake8 = "*"
@@ -34,11 +34,16 @@ pytest-asyncio = "*"
bandit = "*"
pytest-mock = "*"
cryptography = "*"
pyproject-flake8 = "*"
[tool.poetry.extras]
asyncmy = ["asyncmy"]
asyncpg = ["asyncpg"]
aiomysql = ["aiomysql"]
[tool.aerich]
tortoise_orm = "conftest.tortoise_orm"
location = "./migrations"
src_folder = "./."
[build-system]
requires = ["poetry>=0.12"]
@@ -46,3 +51,17 @@ build-backend = "poetry.masonry.api"
[tool.poetry.scripts]
aerich = "aerich.cli:main"
[tool.black]
line-length = 100
target-version = ['py36', 'py37', 'py38', 'py39']
[tool.pytest.ini_options]
asyncio_mode = 'auto'
[tool.mypy]
pretty = true
ignore_missing_imports = true
[tool.flake8]
ignore = 'E501,W503,E203'

View File

@@ -1,2 +0,0 @@
[flake8]
ignore = E501,W503

View File

@@ -65,6 +65,7 @@ class Product(Model):
class Meta:
unique_together = (("name", "type"),)
indexes = (("name", "type"),)
class Config(Model):

View File

@@ -17,6 +17,7 @@ old_models_describe = {
"description": None,
"docstring": None,
"unique_together": [],
"indexes": [],
"pk_field": {
"name": "id",
"field_type": "IntField",
@@ -151,6 +152,7 @@ old_models_describe = {
"description": None,
"docstring": None,
"unique_together": [],
"indexes": [],
"pk_field": {
"name": "id",
"field_type": "IntField",
@@ -242,6 +244,7 @@ old_models_describe = {
"description": None,
"docstring": None,
"unique_together": [],
"indexes": [],
"pk_field": {
"name": "id",
"field_type": "IntField",
@@ -334,6 +337,7 @@ old_models_describe = {
"description": None,
"docstring": None,
"unique_together": [],
"indexes": [],
"pk_field": {
"name": "id",
"field_type": "IntField",
@@ -512,6 +516,7 @@ old_models_describe = {
"description": None,
"docstring": None,
"unique_together": [],
"indexes": [],
"pk_field": {
"name": "id",
"field_type": "IntField",
@@ -681,6 +686,7 @@ old_models_describe = {
"description": None,
"docstring": None,
"unique_together": [],
"indexes": [],
"pk_field": {
"name": "id",
"field_type": "IntField",
@@ -793,6 +799,7 @@ def test_migrate(mocker: MockerFixture):
"ALTER TABLE `configs` RENAME TO `config`",
"ALTER TABLE `product` RENAME COLUMN `image` TO `pic`",
"ALTER TABLE `email` RENAME COLUMN `id` TO `email_id`",
"ALTER TABLE `product` ADD INDEX `idx_product_name_869427` (`name`, `type_db_alias`)",
"ALTER TABLE `email` ADD INDEX `idx_email_email_4a1a33` (`email`)",
"ALTER TABLE `product` ADD UNIQUE INDEX `uid_product_name_869427` (`name`, `type_db_alias`)",
"ALTER TABLE `product` ALTER COLUMN `view_num` SET DEFAULT 0",
@@ -816,6 +823,7 @@ def test_migrate(mocker: MockerFixture):
"ALTER TABLE `config` RENAME TO `configs`",
"ALTER TABLE `product` RENAME COLUMN `pic` TO `image`",
"ALTER TABLE `email` RENAME COLUMN `email_id` TO `id`",
"ALTER TABLE `product` DROP INDEX `idx_product_name_869427`",
"ALTER TABLE `email` DROP INDEX `idx_email_email_4a1a33`",
"ALTER TABLE `product` DROP INDEX `uid_product_name_869427`",
"ALTER TABLE `product` ALTER COLUMN `view_num` DROP DEFAULT",
@@ -843,6 +851,7 @@ def test_migrate(mocker: MockerFixture):
'ALTER TABLE "product" RENAME COLUMN "image" TO "pic"',
'ALTER TABLE "user" ALTER COLUMN "password" TYPE VARCHAR(100) USING "password"::VARCHAR(100)',
'ALTER TABLE "user" DROP COLUMN "avatar"',
'CREATE INDEX "idx_product_name_869427" ON "product" ("name", "type_db_alias")',
'CREATE INDEX "idx_email_email_4a1a33" ON "email" ("email")',
'CREATE TABLE "email_user" ("email_id" INT NOT NULL REFERENCES "email" ("email_id") ON DELETE CASCADE,"user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE)',
'CREATE TABLE IF NOT EXISTS "newmodel" (\n "id" SERIAL NOT NULL PRIMARY KEY,\n "name" VARCHAR(50) NOT NULL\n);\nCOMMENT ON COLUMN "config"."user_id" IS \'User\';',
@@ -865,6 +874,7 @@ def test_migrate(mocker: MockerFixture):
'ALTER TABLE "product" RENAME COLUMN "pic" TO "image"',
'ALTER TABLE "user" ADD "avatar" VARCHAR(200) NOT NULL DEFAULT \'\'',
'ALTER TABLE "user" ALTER COLUMN "password" TYPE VARCHAR(200) USING "password"::VARCHAR(200)',
'DROP INDEX "idx_product_name_869427"',
'DROP INDEX "idx_email_email_4a1a33"',
'DROP INDEX "idx_user_usernam_9987ab"',
'DROP INDEX "uid_product_name_869427"',