fix: inspectdb does not match the 'DOUBLE' and 'CHAR' data types for MySQL

* add:
1. inspectdb adds matching for the DECIMAL, DOUBLE, CHAR and TIME data types;
2. add exception handling so that a single unsupported data type no longer forces the whole table to be created manually.

* fix: aerich inspectdb raises KeyError for 'double' in MySQL

* feat: support command `python -m aerich` (see the sketch after this list)

* docs: update changelog

* tests: verify mysql inspectdb for float field

* fix mysql uuid field inspect to be charfield

* refactor: use `db_index=True` instead of `index=True` for inspectdb

* docs: update changelog
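
The module entry point that enables `python -m aerich` is one of the changed files but is not shown in the excerpt below. A minimal sketch of the usual pattern, assuming the click command group is exposed as `main` in `aerich.cli` (an assumption, not confirmed by this diff):

```python
# aerich/__main__.py -- hypothetical sketch, not the file from this PR.
# Assumes the CLI entry point is `aerich.cli.main`; adjust if the real
# console-script target differs.
from aerich.cli import main

if __name__ == "__main__":
    main()
```

With such a file in place, `python -m aerich init-db` behaves like the `aerich init-db` console script, which the Windows fallback added to `tests/_utils.py` below relies on.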

---------

Co-authored-by: xiechen <xiechen@jinse.com>
Co-authored-by: Waket Zheng <waketzheng@gmail.com>
Authored by 程序猿过家家 on 2025-02-19 16:04:15 +08:00; committed by GitHub
parent 557271c8e1
commit c35282c2a3
9 changed files with 104 additions and 62 deletions


@@ -11,6 +1,7 @@
#### Fixed
- fix: aerich migrate raises tortoise.exceptions.FieldError when `index.INDEX_TYPE` is not empty. ([#415])
- fix: inspectdb raise KeyError 'int2' for smallint. ([#401])
- fix: inspectdb does not match the 'DOUBLE' and 'CHAR' data types for MySQL. ([#187])
### Changed
- Refactored version management to use `importlib.metadata.version(__package__)` instead of hardcoded version string ([#412])


@@ -38,13 +38,12 @@ class Column(BaseModel):
def translate(self) -> ColumnInfoDict:
comment = default = length = index = null = pk = ""
if self.pk:
pk = "pk=True, "
pk = "primary_key=True, "
else:
if self.unique:
index = "unique=True, "
else:
if self.index:
index = "index=True, "
elif self.index:
index = "db_index=True, "
if self.data_type in ("varchar", "VARCHAR"):
length = f"max_length={self.length}, "
elif self.data_type in ("decimal", "numeric"):
@@ -125,62 +124,69 @@ class Inspect:
async def get_all_tables(self) -> list[str]:
raise NotImplementedError
@staticmethod
def get_field_string(
field_class: str, arguments: str = "{null}{default}{comment}", **kwargs
) -> str:
name = kwargs["name"]
field_params = arguments.format(**kwargs).strip().rstrip(",")
return f"{name} = fields.{field_class}({field_params})"
@classmethod
def decimal_field(cls, **kwargs) -> str:
return "{name} = fields.DecimalField({pk}{index}{length}{null}{default}{comment})".format(
**kwargs
)
return cls.get_field_string("DecimalField", **kwargs)
@classmethod
def time_field(cls, **kwargs) -> str:
return "{name} = fields.TimeField({null}{default}{comment})".format(**kwargs)
return cls.get_field_string("TimeField", **kwargs)
@classmethod
def date_field(cls, **kwargs) -> str:
return "{name} = fields.DateField({null}{default}{comment})".format(**kwargs)
return cls.get_field_string("DateField", **kwargs)
@classmethod
def float_field(cls, **kwargs) -> str:
return "{name} = fields.FloatField({null}{default}{comment})".format(**kwargs)
return cls.get_field_string("FloatField", **kwargs)
@classmethod
def datetime_field(cls, **kwargs) -> str:
return "{name} = fields.DatetimeField({null}{default}{comment})".format(**kwargs)
return cls.get_field_string("DatetimeField", **kwargs)
@classmethod
def text_field(cls, **kwargs) -> str:
return "{name} = fields.TextField({null}{default}{comment})".format(**kwargs)
return cls.get_field_string("TextField", **kwargs)
@classmethod
def char_field(cls, **kwargs) -> str:
return "{name} = fields.CharField({pk}{index}{length}{null}{default}{comment})".format(
**kwargs
)
arguments = "{pk}{index}{length}{null}{default}{comment}"
return cls.get_field_string("CharField", arguments, **kwargs)
@classmethod
def int_field(cls, **kwargs) -> str:
return "{name} = fields.IntField({pk}{index}{default}{comment})".format(**kwargs)
def int_field(cls, field_class="IntField", **kwargs) -> str:
arguments = "{pk}{index}{default}{comment}"
return cls.get_field_string(field_class, arguments, **kwargs)
@classmethod
def smallint_field(cls, **kwargs) -> str:
return "{name} = fields.SmallIntField({pk}{index}{default}{comment})".format(**kwargs)
return cls.int_field("SmallIntField", **kwargs)
@classmethod
def bigint_field(cls, **kwargs) -> str:
return "{name} = fields.BigIntField({pk}{index}{default}{comment})".format(**kwargs)
return cls.int_field("BigIntField", **kwargs)
@classmethod
def bool_field(cls, **kwargs) -> str:
return "{name} = fields.BooleanField({null}{default}{comment})".format(**kwargs)
return cls.get_field_string("BooleanField", **kwargs)
@classmethod
def uuid_field(cls, **kwargs) -> str:
return "{name} = fields.UUIDField({pk}{index}{default}{comment})".format(**kwargs)
arguments = "{pk}{index}{default}{comment}"
return cls.get_field_string("UUIDField", arguments, **kwargs)
@classmethod
def json_field(cls, **kwargs) -> str:
return "{name} = fields.JSONField({null}{default}{comment})".format(**kwargs)
return cls.get_field_string("JSONField", **kwargs)
@classmethod
def binary_field(cls, **kwargs) -> str:
return "{name} = fields.BinaryField({null}{default}{comment})".format(**kwargs)
return cls.get_field_string("BinaryField", **kwargs)


@@ -12,11 +12,12 @@ class InspectMySQL(Inspect):
"tinyint": self.bool_field,
"bigint": self.bigint_field,
"varchar": self.char_field,
"char": self.char_field,
"char": self.uuid_field,
"longtext": self.text_field,
"text": self.text_field,
"datetime": self.datetime_field,
"float": self.float_field,
"double": self.float_field,
"date": self.date_field,
"time": self.time_field,
"decimal": self.decimal_field,
@@ -43,6 +44,8 @@ where c.TABLE_SCHEMA = %s
unique = index = False
if (non_unique := row["NON_UNIQUE"]) is not None:
unique = not non_unique
elif row["COLUMN_KEY"] == "UNI":
unique = True
if (index_name := row["INDEX_NAME"]) is not None:
index = index_name != "PRIMARY"
columns.append(
@@ -53,10 +56,8 @@ where c.TABLE_SCHEMA = %s
default=row["COLUMN_DEFAULT"],
pk=row["COLUMN_KEY"] == "PRI",
comment=row["COLUMN_COMMENT"],
unique=row["COLUMN_KEY"] == "UNI",
unique=unique,
extra=row["EXTRA"],
# TODO: why `unque`?
unque=unique, # type:ignore
index=index,
length=row["CHARACTER_MAXIMUM_LENGTH"],
max_digits=row["NUMERIC_PRECISION"],
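
Taken together with the helpers above, the updated map means a MySQL DOUBLE column is now inspected as a FloatField and a CHAR column as a UUIDField. A reduced sketch of the dispatch; the field names mirror the diff, but the TextField fallback for unknown types is illustrative only and not necessarily how the PR's exception handling behaves:

```python
# Illustrative mapping of lower-cased MySQL column types to Tortoise field classes.
FIELD_MAP = {
    "float": "FloatField",
    "double": "FloatField",   # new in this commit: previously raised KeyError
    "char": "UUIDField",      # new in this commit: CHAR(36) columns become UUIDField
    "varchar": "CharField",
    "time": "TimeField",
    "decimal": "DecimalField",
}


def field_for(data_type: str) -> str:
    # Hypothetical fallback -- the PR handles unsupported types via exception
    # handling as described in the commit message; the exact behaviour is not
    # shown in this excerpt.
    return FIELD_MAP.get(data_type.lower(), "TextField")


assert field_for("DOUBLE") == "FloatField"
assert field_for("CHAR") == "UUIDField"
```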


@@ -2,7 +2,9 @@ from __future__ import annotations
import asyncio
import os
import sys
from collections.abc import Generator
from pathlib import Path
import pytest
from tortoise import Tortoise, expand_db_url
@@ -15,7 +17,7 @@ from aerich.ddl.mysql import MysqlDDL
from aerich.ddl.postgres import PostgresDDL
from aerich.ddl.sqlite import SqliteDDL
from aerich.migrate import Migrate
from tests._utils import init_db
from tests._utils import chdir, copy_files, init_db, run_shell
db_url = os.getenv("TEST_DB", MEMORY_SQLITE)
db_url_second = os.getenv("TEST_DB_SECOND", MEMORY_SQLITE)
@@ -66,3 +68,30 @@ async def initialize_tests(event_loop, request) -> None:
Migrate.ddl = PostgresDDL(client)
Migrate.dialect = Migrate.ddl.DIALECT
request.addfinalizer(lambda: event_loop.run_until_complete(Tortoise._drop_databases()))
@pytest.fixture
def new_aerich_project(tmp_path: Path):
test_dir = Path(__file__).parent / "tests"
asset_dir = test_dir / "assets" / "fake"
settings_py = asset_dir / "settings.py"
_tests_py = asset_dir / "_tests.py"
db_py = asset_dir / "db.py"
models_py = test_dir / "models.py"
models_second_py = test_dir / "models_second.py"
copy_files(settings_py, _tests_py, models_py, models_second_py, db_py, target_dir=tmp_path)
dst_dir = tmp_path / "tests"
dst_dir.mkdir()
dst_dir.joinpath("__init__.py").touch()
copy_files(test_dir / "_utils.py", test_dir / "indexes.py", target_dir=dst_dir)
if should_remove := str(tmp_path) not in sys.path:
sys.path.append(str(tmp_path))
with chdir(tmp_path):
run_shell("python db.py create", capture_output=False)
try:
yield
finally:
if not os.getenv("AERICH_DONT_DROP_FAKE_DB"):
run_shell("python db.py drop", capture_output=False)
if should_remove:
sys.path.remove(str(tmp_path))


@@ -1,5 +1,6 @@
import contextlib
import os
import platform
import shlex
import shutil
import subprocess
@@ -72,7 +73,12 @@ class Dialect:
return not cls.test_db_url or "sqlite" in cls.test_db_url
WINDOWS = platform.system() == "Windows"
def run_shell(command: str, capture_output=True, **kw) -> str:
if WINDOWS and command.startswith("aerich "):
command = "python -m " + command
r = subprocess.run(shlex.split(command), capture_output=capture_output)
if r.returncode != 0 and r.stderr:
return r.stderr.decode()
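
Usage note: the test helpers call `run_shell("aerich ...")` directly, and on Windows the command is rewritten to go through the new module entry point. A standalone illustration of the rewrite, without spawning a subprocess:

```python
import shlex


def rewrite_for_windows(command: str) -> list[str]:
    # Mirrors the branch above: fall back to `python -m aerich` when the
    # `aerich` console script may not be resolvable on Windows (the reason is
    # an assumption; the diff only shows the rewrite itself).
    if command.startswith("aerich "):
        command = "python -m " + command
    return shlex.split(command)


print(rewrite_for_windows("aerich inspectdb -t product"))
# -> ['python', '-m', 'aerich', 'inspectdb', '-t', 'product']
```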


@@ -93,6 +93,8 @@ class Product(Model):
)
pic = fields.CharField(max_length=200)
body = fields.TextField()
price = fields.FloatField(null=True)
no = fields.UUIDField(db_index=True)
created_at = fields.DatetimeField(auto_now_add=True)
is_deleted = fields.BooleanField(default=False)


@@ -2,41 +2,9 @@ from __future__ import annotations
import os
import re
import sys
from pathlib import Path
import pytest
from aerich.ddl.sqlite import SqliteDDL
from aerich.migrate import Migrate
from tests._utils import chdir, copy_files, run_shell
@pytest.fixture
def new_aerich_project(tmp_path: Path):
test_dir = Path(__file__).parent
asset_dir = test_dir / "assets" / "fake"
settings_py = asset_dir / "settings.py"
_tests_py = asset_dir / "_tests.py"
db_py = asset_dir / "db.py"
models_py = test_dir / "models.py"
models_second_py = test_dir / "models_second.py"
copy_files(settings_py, _tests_py, models_py, models_second_py, db_py, target_dir=tmp_path)
dst_dir = tmp_path / "tests"
dst_dir.mkdir()
dst_dir.joinpath("__init__.py").touch()
copy_files(test_dir / "_utils.py", test_dir / "indexes.py", target_dir=dst_dir)
if should_remove := str(tmp_path) not in sys.path:
sys.path.append(str(tmp_path))
with chdir(tmp_path):
run_shell("python db.py create", capture_output=False)
try:
yield
finally:
if not os.getenv("AERICH_DONT_DROP_FAKE_DB"):
run_shell("python db.py drop", capture_output=False)
if should_remove:
sys.path.remove(str(tmp_path))
from tests._utils import Dialect, run_shell
def _append_field(*files: str, name="field_1") -> None:
@@ -48,7 +16,7 @@ def _append_field(*files: str, name="field_1") -> None:
def test_fake(new_aerich_project):
if (ddl := getattr(Migrate, "ddl", None)) and isinstance(ddl, SqliteDDL):
if Dialect.is_sqlite():
# TODO: go ahead if sqlite alter-column supported
return
output = run_shell("aerich init -t settings.TORTOISE_ORM")

tests/test_inspectdb.py (new file, 17 additions)

@@ -0,0 +1,17 @@
from tests._utils import Dialect, run_shell
def test_inspect(new_aerich_project):
if Dialect.is_sqlite():
# TODO: test sqlite after #384 fixed
return
run_shell("aerich init -t settings.TORTOISE_ORM")
run_shell("aerich init-db")
ret = run_shell("aerich inspectdb -t product")
assert ret.startswith("from tortoise import Model, fields")
assert "primary_key=True" in ret
assert "fields.DatetimeField" in ret
assert "fields.FloatField" in ret
assert "fields.UUIDField" in ret
if Dialect.is_mysql():
assert "db_index=True" in ret


@@ -981,8 +981,11 @@ def test_migrate(mocker: MockerFixture):
"ALTER TABLE `product` DROP COLUMN `uuid`",
"ALTER TABLE `product` DROP INDEX `uuid`",
"ALTER TABLE `product` RENAME COLUMN `image` TO `pic`",
"ALTER TABLE `product` ADD `price` DOUBLE",
"ALTER TABLE `product` ADD `no` CHAR(36) NOT NULL",
"ALTER TABLE `email` RENAME COLUMN `id` TO `email_id`",
"ALTER TABLE `product` ADD INDEX `idx_product_name_869427` (`name`, `type_db_alias`)",
"ALTER TABLE `product` ADD INDEX `idx_product_no_e4d701` (`no`)",
"ALTER TABLE `email` ADD INDEX `idx_email_email_4a1a33` (`email`)",
"ALTER TABLE `product` ADD UNIQUE INDEX `uid_product_name_869427` (`name`, `type_db_alias`)",
"ALTER TABLE `product` ALTER COLUMN `view_num` SET DEFAULT 0",
@@ -1027,8 +1030,11 @@ def test_migrate(mocker: MockerFixture):
"ALTER TABLE `product` ADD `uuid` INT NOT NULL UNIQUE",
"ALTER TABLE `product` ADD UNIQUE INDEX `uuid` (`uuid`)",
"ALTER TABLE `product` DROP INDEX `idx_product_name_869427`",
"ALTER TABLE `product` DROP COLUMN `price`",
"ALTER TABLE `product` DROP COLUMN `no`",
"ALTER TABLE `email` DROP INDEX `idx_email_email_4a1a33`",
"ALTER TABLE `product` DROP INDEX `uid_product_name_869427`",
"ALTER TABLE `product` DROP INDEX `idx_product_no_e4d701`",
"ALTER TABLE `product` ALTER COLUMN `view_num` DROP DEFAULT",
"ALTER TABLE `product` RENAME COLUMN `is_deleted` TO `is_delete`",
"ALTER TABLE `product` RENAME COLUMN `is_reviewed` TO `is_review`",
@@ -1074,11 +1080,14 @@ def test_migrate(mocker: MockerFixture):
'ALTER TABLE "product" RENAME COLUMN "image" TO "pic"',
'ALTER TABLE "product" RENAME COLUMN "is_review" TO "is_reviewed"',
'ALTER TABLE "product" RENAME COLUMN "is_delete" TO "is_deleted"',
'ALTER TABLE "product" ADD "price" DOUBLE PRECISION',
'ALTER TABLE "product" ADD "no" UUID NOT NULL',
'ALTER TABLE "user" ALTER COLUMN "password" TYPE VARCHAR(100) USING "password"::VARCHAR(100)',
'ALTER TABLE "user" DROP COLUMN "avatar"',
'ALTER TABLE "user" ALTER COLUMN "longitude" TYPE DECIMAL(10,8) USING "longitude"::DECIMAL(10,8)',
'CREATE INDEX IF NOT EXISTS "idx_product_name_869427" ON "product" ("name", "type_db_alias")',
'CREATE INDEX IF NOT EXISTS "idx_email_email_4a1a33" ON "email" ("email")',
'CREATE INDEX IF NOT EXISTS "idx_product_no_e4d701" ON "product" ("no")',
'CREATE TABLE "email_user" (\n "email_id" INT NOT NULL REFERENCES "email" ("email_id") ON DELETE CASCADE,\n "user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE\n)',
'CREATE TABLE IF NOT EXISTS "newmodel" (\n "id" SERIAL NOT NULL PRIMARY KEY,\n "name" VARCHAR(50) NOT NULL\n);\nCOMMENT ON COLUMN "config"."user_id" IS \'User\'',
'CREATE UNIQUE INDEX IF NOT EXISTS "uid_product_name_869427" ON "product" ("name", "type_db_alias")',
@@ -1118,6 +1127,8 @@ def test_migrate(mocker: MockerFixture):
'ALTER TABLE "product" RENAME COLUMN "pic" TO "image"',
'ALTER TABLE "product" RENAME COLUMN "is_deleted" TO "is_delete"',
'ALTER TABLE "product" RENAME COLUMN "is_reviewed" TO "is_review"',
'ALTER TABLE "product" DROP COLUMN "price"',
'ALTER TABLE "product" DROP COLUMN "no"',
'ALTER TABLE "user" ADD "avatar" VARCHAR(200) NOT NULL DEFAULT \'\'',
'ALTER TABLE "user" ALTER COLUMN "password" TYPE VARCHAR(200) USING "password"::VARCHAR(200)',
'ALTER TABLE "user" ALTER COLUMN "longitude" TYPE DECIMAL(12,9) USING "longitude"::DECIMAL(12,9)',
@@ -1126,6 +1137,7 @@ def test_migrate(mocker: MockerFixture):
'DROP INDEX IF EXISTS "idx_email_email_4a1a33"',
'DROP INDEX IF EXISTS "uid_user_usernam_9987ab"',
'DROP INDEX IF EXISTS "uid_product_name_869427"',
'DROP INDEX IF EXISTS "idx_product_no_e4d701"',
'DROP TABLE IF EXISTS "email_user"',
'DROP TABLE IF EXISTS "newmodel"',
'CREATE TABLE "config_category" (\n "config_id" INT NOT NULL REFERENCES "config" ("id") ON DELETE CASCADE,\n "category_id" INT NOT NULL REFERENCES "category" ("id") ON DELETE CASCADE\n)',