diff --git a/CHANGELOG.md b/CHANGELOG.md index 4834f8e..657e1a4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ #### Fixed - fix: aerich migrate raises tortoise.exceptions.FieldError when `index.INDEX_TYPE` is not empty. ([#415]) - fix: inspectdb raise KeyError 'int2' for smallint. ([#401]) +- fix: inspectdb not match data type 'DOUBLE' and 'CHAR' for MySQL. ([#187]) ### Changed - Refactored version management to use `importlib.metadata.version(__package__)` instead of hardcoded version string ([#412]) diff --git a/aerich/inspectdb/__init__.py b/aerich/inspectdb/__init__.py index ac9dff0..512946e 100644 --- a/aerich/inspectdb/__init__.py +++ b/aerich/inspectdb/__init__.py @@ -38,13 +38,12 @@ class Column(BaseModel): def translate(self) -> ColumnInfoDict: comment = default = length = index = null = pk = "" if self.pk: - pk = "pk=True, " + pk = "primary_key=True, " else: if self.unique: index = "unique=True, " - else: - if self.index: - index = "index=True, " + elif self.index: + index = "db_index=True, " if self.data_type in ("varchar", "VARCHAR"): length = f"max_length={self.length}, " elif self.data_type in ("decimal", "numeric"): @@ -125,62 +124,69 @@ class Inspect: async def get_all_tables(self) -> list[str]: raise NotImplementedError + @staticmethod + def get_field_string( + field_class: str, arguments: str = "{null}{default}{comment}", **kwargs + ) -> str: + name = kwargs["name"] + field_params = arguments.format(**kwargs).strip().rstrip(",") + return f"{name} = fields.{field_class}({field_params})" + @classmethod def decimal_field(cls, **kwargs) -> str: - return "{name} = fields.DecimalField({pk}{index}{length}{null}{default}{comment})".format( - **kwargs - ) + return cls.get_field_string("DecimalField", **kwargs) @classmethod def time_field(cls, **kwargs) -> str: - return "{name} = fields.TimeField({null}{default}{comment})".format(**kwargs) + return cls.get_field_string("TimeField", **kwargs) @classmethod def date_field(cls, **kwargs) -> str: - return "{name} = fields.DateField({null}{default}{comment})".format(**kwargs) + return cls.get_field_string("DateField", **kwargs) @classmethod def float_field(cls, **kwargs) -> str: - return "{name} = fields.FloatField({null}{default}{comment})".format(**kwargs) + return cls.get_field_string("FloatField", **kwargs) @classmethod def datetime_field(cls, **kwargs) -> str: - return "{name} = fields.DatetimeField({null}{default}{comment})".format(**kwargs) + return cls.get_field_string("DatetimeField", **kwargs) @classmethod def text_field(cls, **kwargs) -> str: - return "{name} = fields.TextField({null}{default}{comment})".format(**kwargs) + return cls.get_field_string("TextField", **kwargs) @classmethod def char_field(cls, **kwargs) -> str: - return "{name} = fields.CharField({pk}{index}{length}{null}{default}{comment})".format( - **kwargs - ) + arguments = "{pk}{index}{length}{null}{default}{comment}" + return cls.get_field_string("CharField", arguments, **kwargs) @classmethod - def int_field(cls, **kwargs) -> str: - return "{name} = fields.IntField({pk}{index}{default}{comment})".format(**kwargs) + def int_field(cls, field_class="IntField", **kwargs) -> str: + arguments = "{pk}{index}{default}{comment}" + return cls.get_field_string(field_class, arguments, **kwargs) @classmethod def smallint_field(cls, **kwargs) -> str: - return "{name} = fields.SmallIntField({pk}{index}{default}{comment})".format(**kwargs) + return cls.int_field("SmallIntField", **kwargs) @classmethod def bigint_field(cls, **kwargs) -> str: - return "{name} = fields.BigIntField({pk}{index}{default}{comment})".format(**kwargs) + return cls.int_field("BigIntField", **kwargs) @classmethod def bool_field(cls, **kwargs) -> str: - return "{name} = fields.BooleanField({null}{default}{comment})".format(**kwargs) + return cls.get_field_string("BooleanField", **kwargs) @classmethod def uuid_field(cls, **kwargs) -> str: - return "{name} = fields.UUIDField({pk}{index}{default}{comment})".format(**kwargs) + arguments = "{pk}{index}{default}{comment}" + return cls.get_field_string("UUIDField", arguments, **kwargs) @classmethod def json_field(cls, **kwargs) -> str: - return "{name} = fields.JSONField({null}{default}{comment})".format(**kwargs) + return cls.get_field_string("JSONField", **kwargs) @classmethod def binary_field(cls, **kwargs) -> str: - return "{name} = fields.BinaryField({null}{default}{comment})".format(**kwargs) + return cls.get_field_string("BinaryField", **kwargs) diff --git a/aerich/inspectdb/mysql.py b/aerich/inspectdb/mysql.py index 8986058..1034fd9 100644 --- a/aerich/inspectdb/mysql.py +++ b/aerich/inspectdb/mysql.py @@ -12,11 +12,12 @@ class InspectMySQL(Inspect): "tinyint": self.bool_field, "bigint": self.bigint_field, "varchar": self.char_field, - "char": self.char_field, + "char": self.uuid_field, "longtext": self.text_field, "text": self.text_field, "datetime": self.datetime_field, "float": self.float_field, + "double": self.float_field, "date": self.date_field, "time": self.time_field, "decimal": self.decimal_field, @@ -43,6 +44,8 @@ where c.TABLE_SCHEMA = %s unique = index = False if (non_unique := row["NON_UNIQUE"]) is not None: unique = not non_unique + elif row["COLUMN_KEY"] == "UNI": + unique = True if (index_name := row["INDEX_NAME"]) is not None: index = index_name != "PRIMARY" columns.append( @@ -53,10 +56,8 @@ where c.TABLE_SCHEMA = %s default=row["COLUMN_DEFAULT"], pk=row["COLUMN_KEY"] == "PRI", comment=row["COLUMN_COMMENT"], - unique=row["COLUMN_KEY"] == "UNI", + unique=unique, extra=row["EXTRA"], - # TODO: why `unque`? - unque=unique, # type:ignore index=index, length=row["CHARACTER_MAXIMUM_LENGTH"], max_digits=row["NUMERIC_PRECISION"], diff --git a/conftest.py b/conftest.py index 5b30ae3..f91f6de 100644 --- a/conftest.py +++ b/conftest.py @@ -2,7 +2,9 @@ from __future__ import annotations import asyncio import os +import sys from collections.abc import Generator +from pathlib import Path import pytest from tortoise import Tortoise, expand_db_url @@ -15,7 +17,7 @@ from aerich.ddl.mysql import MysqlDDL from aerich.ddl.postgres import PostgresDDL from aerich.ddl.sqlite import SqliteDDL from aerich.migrate import Migrate -from tests._utils import init_db +from tests._utils import chdir, copy_files, init_db, run_shell db_url = os.getenv("TEST_DB", MEMORY_SQLITE) db_url_second = os.getenv("TEST_DB_SECOND", MEMORY_SQLITE) @@ -66,3 +68,30 @@ async def initialize_tests(event_loop, request) -> None: Migrate.ddl = PostgresDDL(client) Migrate.dialect = Migrate.ddl.DIALECT request.addfinalizer(lambda: event_loop.run_until_complete(Tortoise._drop_databases())) + + +@pytest.fixture +def new_aerich_project(tmp_path: Path): + test_dir = Path(__file__).parent / "tests" + asset_dir = test_dir / "assets" / "fake" + settings_py = asset_dir / "settings.py" + _tests_py = asset_dir / "_tests.py" + db_py = asset_dir / "db.py" + models_py = test_dir / "models.py" + models_second_py = test_dir / "models_second.py" + copy_files(settings_py, _tests_py, models_py, models_second_py, db_py, target_dir=tmp_path) + dst_dir = tmp_path / "tests" + dst_dir.mkdir() + dst_dir.joinpath("__init__.py").touch() + copy_files(test_dir / "_utils.py", test_dir / "indexes.py", target_dir=dst_dir) + if should_remove := str(tmp_path) not in sys.path: + sys.path.append(str(tmp_path)) + with chdir(tmp_path): + run_shell("python db.py create", capture_output=False) + try: + yield + finally: + if not os.getenv("AERICH_DONT_DROP_FAKE_DB"): + run_shell("python db.py drop", capture_output=False) + if should_remove: + sys.path.remove(str(tmp_path)) diff --git a/tests/_utils.py b/tests/_utils.py index c509a2d..f6e067c 100644 --- a/tests/_utils.py +++ b/tests/_utils.py @@ -1,5 +1,6 @@ import contextlib import os +import platform import shlex import shutil import subprocess @@ -72,7 +73,12 @@ class Dialect: return not cls.test_db_url or "sqlite" in cls.test_db_url +WINDOWS = platform.system() == "Windows" + + def run_shell(command: str, capture_output=True, **kw) -> str: + if WINDOWS and command.startswith("aerich "): + command = "python -m " + command r = subprocess.run(shlex.split(command), capture_output=capture_output) if r.returncode != 0 and r.stderr: return r.stderr.decode() diff --git a/tests/models.py b/tests/models.py index 757a642..6879da7 100644 --- a/tests/models.py +++ b/tests/models.py @@ -93,6 +93,8 @@ class Product(Model): ) pic = fields.CharField(max_length=200) body = fields.TextField() + price = fields.FloatField(null=True) + no = fields.UUIDField(db_index=True) created_at = fields.DatetimeField(auto_now_add=True) is_deleted = fields.BooleanField(default=False) diff --git a/tests/test_fake.py b/tests/test_fake.py index 8860457..c311c98 100644 --- a/tests/test_fake.py +++ b/tests/test_fake.py @@ -2,41 +2,9 @@ from __future__ import annotations import os import re -import sys from pathlib import Path -import pytest - -from aerich.ddl.sqlite import SqliteDDL -from aerich.migrate import Migrate -from tests._utils import chdir, copy_files, run_shell - - -@pytest.fixture -def new_aerich_project(tmp_path: Path): - test_dir = Path(__file__).parent - asset_dir = test_dir / "assets" / "fake" - settings_py = asset_dir / "settings.py" - _tests_py = asset_dir / "_tests.py" - db_py = asset_dir / "db.py" - models_py = test_dir / "models.py" - models_second_py = test_dir / "models_second.py" - copy_files(settings_py, _tests_py, models_py, models_second_py, db_py, target_dir=tmp_path) - dst_dir = tmp_path / "tests" - dst_dir.mkdir() - dst_dir.joinpath("__init__.py").touch() - copy_files(test_dir / "_utils.py", test_dir / "indexes.py", target_dir=dst_dir) - if should_remove := str(tmp_path) not in sys.path: - sys.path.append(str(tmp_path)) - with chdir(tmp_path): - run_shell("python db.py create", capture_output=False) - try: - yield - finally: - if not os.getenv("AERICH_DONT_DROP_FAKE_DB"): - run_shell("python db.py drop", capture_output=False) - if should_remove: - sys.path.remove(str(tmp_path)) +from tests._utils import Dialect, run_shell def _append_field(*files: str, name="field_1") -> None: @@ -48,7 +16,7 @@ def _append_field(*files: str, name="field_1") -> None: def test_fake(new_aerich_project): - if (ddl := getattr(Migrate, "ddl", None)) and isinstance(ddl, SqliteDDL): + if Dialect.is_sqlite(): # TODO: go ahead if sqlite alter-column supported return output = run_shell("aerich init -t settings.TORTOISE_ORM") diff --git a/tests/test_inspectdb.py b/tests/test_inspectdb.py new file mode 100644 index 0000000..3a63613 --- /dev/null +++ b/tests/test_inspectdb.py @@ -0,0 +1,17 @@ +from tests._utils import Dialect, run_shell + + +def test_inspect(new_aerich_project): + if Dialect.is_sqlite(): + # TODO: test sqlite after #384 fixed + return + run_shell("aerich init -t settings.TORTOISE_ORM") + run_shell("aerich init-db") + ret = run_shell("aerich inspectdb -t product") + assert ret.startswith("from tortoise import Model, fields") + assert "primary_key=True" in ret + assert "fields.DatetimeField" in ret + assert "fields.FloatField" in ret + assert "fields.UUIDField" in ret + if Dialect.is_mysql(): + assert "db_index=True" in ret diff --git a/tests/test_migrate.py b/tests/test_migrate.py index 7b8f227..b4152f0 100644 --- a/tests/test_migrate.py +++ b/tests/test_migrate.py @@ -981,8 +981,11 @@ def test_migrate(mocker: MockerFixture): "ALTER TABLE `product` DROP COLUMN `uuid`", "ALTER TABLE `product` DROP INDEX `uuid`", "ALTER TABLE `product` RENAME COLUMN `image` TO `pic`", + "ALTER TABLE `product` ADD `price` DOUBLE", + "ALTER TABLE `product` ADD `no` CHAR(36) NOT NULL", "ALTER TABLE `email` RENAME COLUMN `id` TO `email_id`", "ALTER TABLE `product` ADD INDEX `idx_product_name_869427` (`name`, `type_db_alias`)", + "ALTER TABLE `product` ADD INDEX `idx_product_no_e4d701` (`no`)", "ALTER TABLE `email` ADD INDEX `idx_email_email_4a1a33` (`email`)", "ALTER TABLE `product` ADD UNIQUE INDEX `uid_product_name_869427` (`name`, `type_db_alias`)", "ALTER TABLE `product` ALTER COLUMN `view_num` SET DEFAULT 0", @@ -1027,8 +1030,11 @@ def test_migrate(mocker: MockerFixture): "ALTER TABLE `product` ADD `uuid` INT NOT NULL UNIQUE", "ALTER TABLE `product` ADD UNIQUE INDEX `uuid` (`uuid`)", "ALTER TABLE `product` DROP INDEX `idx_product_name_869427`", + "ALTER TABLE `product` DROP COLUMN `price`", + "ALTER TABLE `product` DROP COLUMN `no`", "ALTER TABLE `email` DROP INDEX `idx_email_email_4a1a33`", "ALTER TABLE `product` DROP INDEX `uid_product_name_869427`", + "ALTER TABLE `product` DROP INDEX `idx_product_no_e4d701`", "ALTER TABLE `product` ALTER COLUMN `view_num` DROP DEFAULT", "ALTER TABLE `product` RENAME COLUMN `is_deleted` TO `is_delete`", "ALTER TABLE `product` RENAME COLUMN `is_reviewed` TO `is_review`", @@ -1074,11 +1080,14 @@ def test_migrate(mocker: MockerFixture): 'ALTER TABLE "product" RENAME COLUMN "image" TO "pic"', 'ALTER TABLE "product" RENAME COLUMN "is_review" TO "is_reviewed"', 'ALTER TABLE "product" RENAME COLUMN "is_delete" TO "is_deleted"', + 'ALTER TABLE "product" ADD "price" DOUBLE PRECISION', + 'ALTER TABLE "product" ADD "no" UUID NOT NULL', 'ALTER TABLE "user" ALTER COLUMN "password" TYPE VARCHAR(100) USING "password"::VARCHAR(100)', 'ALTER TABLE "user" DROP COLUMN "avatar"', 'ALTER TABLE "user" ALTER COLUMN "longitude" TYPE DECIMAL(10,8) USING "longitude"::DECIMAL(10,8)', 'CREATE INDEX IF NOT EXISTS "idx_product_name_869427" ON "product" ("name", "type_db_alias")', 'CREATE INDEX IF NOT EXISTS "idx_email_email_4a1a33" ON "email" ("email")', + 'CREATE INDEX IF NOT EXISTS "idx_product_no_e4d701" ON "product" ("no")', 'CREATE TABLE "email_user" (\n "email_id" INT NOT NULL REFERENCES "email" ("email_id") ON DELETE CASCADE,\n "user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE\n)', 'CREATE TABLE IF NOT EXISTS "newmodel" (\n "id" SERIAL NOT NULL PRIMARY KEY,\n "name" VARCHAR(50) NOT NULL\n);\nCOMMENT ON COLUMN "config"."user_id" IS \'User\'', 'CREATE UNIQUE INDEX IF NOT EXISTS "uid_product_name_869427" ON "product" ("name", "type_db_alias")', @@ -1118,6 +1127,8 @@ def test_migrate(mocker: MockerFixture): 'ALTER TABLE "product" RENAME COLUMN "pic" TO "image"', 'ALTER TABLE "product" RENAME COLUMN "is_deleted" TO "is_delete"', 'ALTER TABLE "product" RENAME COLUMN "is_reviewed" TO "is_review"', + 'ALTER TABLE "product" DROP COLUMN "price"', + 'ALTER TABLE "product" DROP COLUMN "no"', 'ALTER TABLE "user" ADD "avatar" VARCHAR(200) NOT NULL DEFAULT \'\'', 'ALTER TABLE "user" ALTER COLUMN "password" TYPE VARCHAR(200) USING "password"::VARCHAR(200)', 'ALTER TABLE "user" ALTER COLUMN "longitude" TYPE DECIMAL(12,9) USING "longitude"::DECIMAL(12,9)', @@ -1126,6 +1137,7 @@ def test_migrate(mocker: MockerFixture): 'DROP INDEX IF EXISTS "idx_email_email_4a1a33"', 'DROP INDEX IF EXISTS "uid_user_usernam_9987ab"', 'DROP INDEX IF EXISTS "uid_product_name_869427"', + 'DROP INDEX IF EXISTS "idx_product_no_e4d701"', 'DROP TABLE IF EXISTS "email_user"', 'DROP TABLE IF EXISTS "newmodel"', 'CREATE TABLE "config_category" (\n "config_id" INT NOT NULL REFERENCES "config" ("id") ON DELETE CASCADE,\n "category_id" INT NOT NULL REFERENCES "category" ("id") ON DELETE CASCADE\n)',