diff --git a/aerich/cli.py b/aerich/cli.py index 20fa4a6..94c6a01 100644 --- a/aerich/cli.py +++ b/aerich/cli.py @@ -36,7 +36,6 @@ def coro(f): def wrapper(*args, **kwargs): loop = asyncio.get_event_loop() loop.run_until_complete(f(*args, **kwargs)) - loop.run_until_complete(Tortoise.close_connections()) return wrapper @@ -221,9 +220,9 @@ async def history(ctx: Context): @click.pass_context @coro async def init( - ctx: Context, - tortoise_orm, - location, + ctx: Context, + tortoise_orm, + location, ): config_file = ctx.obj["config_file"] name = ctx.obj["name"] diff --git a/aerich/ddl/__init__.py b/aerich/ddl/__init__.py index 691e90f..287e38f 100644 --- a/aerich/ddl/__init__.py +++ b/aerich/ddl/__init__.py @@ -59,17 +59,16 @@ class BaseDDL: def drop_m2m(self, field: ManyToManyFieldInstance): return self._DROP_TABLE_TEMPLATE.format(table_name=field.through) - def _get_default(self, model: "Type[Model]", field_object: Field): + def _get_default(self, model: "Type[Model]", field_describe: dict): db_table = model._meta.db_table - default = field_object.default - db_column = field_object.model_field_name - auto_now_add = getattr(field_object, "auto_now_add", False) - auto_now = getattr(field_object, "auto_now", False) + default = field_describe.get('default') + db_column = field_describe.get('db_column') + auto_now_add = field_describe.get("auto_now_add", False) + auto_now = field_describe.get( "auto_now", False) if default is not None or auto_now_add: - if callable(default) or isinstance(field_object, (UUIDField, TextField, JSONField)): + if field_describe.get('field_type')in ['UUIDField', 'TextField', 'JSONField']: default = "" else: - default = field_object.to_db_value(default, model) try: default = self.schema_generator._column_default_generator( db_table, @@ -104,13 +103,13 @@ class BaseDDL: if description else "", is_primary_key=is_pk, - default=field_describe.get("default"), + default=self._get_default(model,field_describe), ), ) - def drop_column(self, model: "Type[Model]", column_name: str): + def drop_column(self, model: "Type[Model]", field_describe: dict): return self._DROP_COLUMN_TEMPLATE.format( - table_name=model._meta.db_table, column_name=column_name + table_name=model._meta.db_table, column_name=field_describe.get('db_column') ) def modify_column(self, model: "Type[Model]", field_object: Field): @@ -142,7 +141,7 @@ class BaseDDL: ) def change_column( - self, model: "Type[Model]", old_column_name: str, new_column_name: str, new_column_type: str + self, model: "Type[Model]", old_column_name: str, new_column_name: str, new_column_type: str ): return self._CHANGE_COLUMN_TEMPLATE.format( table_name=model._meta.db_table, @@ -169,37 +168,34 @@ class BaseDDL: table_name=model._meta.db_table, ) - def add_fk(self, model: "Type[Model]", field: dict): + def add_fk(self, model: "Type[Model]", field_describe: dict, field_describe_target: dict): db_table = model._meta.db_table - db_column = field.get("db_column") + db_column = field_describe.get("raw_field") fk_name = self.schema_generator._generate_fk_name( from_table=db_table, from_field=db_column, - to_table=field.related_model._meta.db_table, + to_table=field_describe.get('name'), to_field=db_column, ) return self._ADD_FK_TEMPLATE.format( table_name=db_table, fk_name=fk_name, db_column=db_column, - table=field.related_model._meta.db_table, + table=field_describe.get('name'), field=db_column, - on_delete=field.get("on_delete"), + on_delete=field_describe.get('on_delete'), ) - def drop_fk(self, model: "Type[Model]", field: ForeignKeyFieldInstance): - to_field_name = field.to_field_instance.source_field - if not to_field_name: - to_field_name = field.to_field_instance.model_field_name + def drop_fk(self, model: "Type[Model]", field_describe: dict, field_describe_target: dict): db_table = model._meta.db_table return self._DROP_FK_TEMPLATE.format( table_name=db_table, fk_name=self.schema_generator._generate_fk_name( from_table=db_table, - from_field=field.source_field or field.model_field_name + "_id", - to_table=field.related_model._meta.db_table, - to_field=to_field_name, + from_field=field_describe.get('raw_field'), + to_table=field_describe.get('name'), + to_field=field_describe_target.get('db_column'), ), ) diff --git a/aerich/migrate.py b/aerich/migrate.py index f6279a5..7f9b930 100644 --- a/aerich/migrate.py +++ b/aerich/migrate.py @@ -3,7 +3,7 @@ from datetime import datetime from pathlib import Path from typing import Dict, List, Optional, Tuple, Type -from dictdiffer import diff +import click from tortoise import ( BackwardFKRelation, BackwardOneToOneRelation, @@ -204,30 +204,49 @@ class Migrate: :param upgrade: :return: """ - for change in diff(old_model_describe, new_model_describe): - action, field_type, fields = change - is_pk = field_type == "pk_field" - if action == "add": - for field in fields: - _, field_describe = field - cls._add_operator( - cls._add_field( - cls._get_model(new_model_describe.get("name").split(".")[1]), - field_describe, - is_pk, - ), - upgrade, - ) - elif action == "remove": - for field in fields: - _, field_describe = field - cls._add_operator( - cls._remove_field( - cls._get_model(new_model_describe.get("name").split(".")[1]), - field_describe, - ), - upgrade, - ) + + old_unique_together = old_model_describe.get('unique_together') + new_unique_together = new_model_describe.get('unique_together') + + old_data_fields = old_model_describe.get('data_fields') + new_data_fields = new_model_describe.get('data_fields') + + old_data_fields_name = list(map(lambda x: x.get('name'), old_data_fields)) + new_data_fields_name = list(map(lambda x: x.get('name'), new_data_fields)) + + model = cls._get_model(new_model_describe.get('name').split('.')[1]) + # add fields + for new_data_field_name in set(new_data_fields_name).difference(set(old_data_fields_name)): + cls._add_operator( + cls._add_field(model, next(filter(lambda x: x.get('name') == new_data_field_name, new_data_fields))), + upgrade) + # remove fields + for old_data_field_name in set(old_data_fields_name).difference(set(new_data_fields_name)): + cls._add_operator( + cls._remove_field(model, next(filter(lambda x: x.get('name') == old_data_field_name, old_data_fields))), + upgrade) + + old_fk_fields = old_model_describe.get('fk_fields') + new_fk_fields = new_model_describe.get('fk_fields') + + old_fk_fields_name = list(map(lambda x: x.get('name'), old_fk_fields)) + new_fk_fields_name = list(map(lambda x: x.get('name'), new_fk_fields)) + + # add fk + for new_fk_field_name in set(new_fk_fields_name).difference(set(old_fk_fields_name)): + fk_field = next(filter(lambda x: x.get('name') == new_fk_field_name, new_fk_fields)) + cls._add_operator( + cls._add_fk(model, fk_field, + next(filter(lambda x: x.get('db_column') == fk_field.get('raw_field'), new_data_fields))), + upgrade) + # drop fk + for old_fk_field_name in set(old_fk_fields_name).difference(set(new_fk_fields_name)): + old_fk_field = next(filter(lambda x: x.get('name') == old_fk_field_name, old_fk_fields)) + cls._add_operator( + cls._drop_fk( + model, old_fk_field, + next(filter(lambda x: x.get('db_column') == old_fk_field.get('raw_field'), old_data_fields))), + upgrade) @classmethod def _resolve_fk_fields_name(cls, model: Type[Model], fields_name: Tuple[str]): @@ -273,12 +292,8 @@ class Migrate: return isinstance(field, (BackwardFKRelation, BackwardOneToOneRelation)) @classmethod - def _add_field(cls, model: Type[Model], field: dict, is_pk: bool = False): - if field.get("field_type") == "ForeignKeyFieldInstance": - return cls.ddl.add_fk(model, field) - if field.get("field_type") == "ManyToManyFieldInstance": - return cls.ddl.create_m2m_table(model, field) - return cls.ddl.add_column(model, field, is_pk) + def _add_field(cls, model: Type[Model], field_describe: dict, is_pk: bool = False): + return cls.ddl.add_column(model, field_describe, is_pk) @classmethod def _alter_default(cls, model: Type[Model], field: Field): @@ -297,16 +312,12 @@ class Migrate: return cls.ddl.modify_column(model, field) @classmethod - def _drop_fk(cls, model: Type[Model], field: ForeignKeyFieldInstance): - return cls.ddl.drop_fk(model, field) + def _drop_fk(cls, model: Type[Model], field_describe: dict, field_describe_target: dict): + return cls.ddl.drop_fk(model, field_describe, field_describe_target) @classmethod - def _remove_field(cls, model: Type[Model], field: Field): - if isinstance(field, ForeignKeyFieldInstance): - return cls.ddl.drop_fk(model, field) - if isinstance(field, ManyToManyFieldInstance): - return cls.ddl.drop_m2m(field) - return cls.ddl.drop_column(model, field.model_field_name) + def _remove_field(cls, model: Type[Model], field_describe: dict): + return cls.ddl.drop_column(model, field_describe) @classmethod def _rename_field(cls, model: Type[Model], old_field: Field, new_field: Field): @@ -322,24 +333,14 @@ class Migrate: ) @classmethod - def _add_fk(cls, model: Type[Model], field: ForeignKeyFieldInstance): + def _add_fk(cls, model: Type[Model], field_describe: dict, field_describe_target: dict): """ add fk :param model: :param field: :return: """ - return cls.ddl.add_fk(model, field) - - @classmethod - def _remove_fk(cls, model: Type[Model], field: ForeignKeyFieldInstance): - """ - drop fk - :param model: - :param field: - :return: - """ - return cls.ddl.drop_fk(model, field) + return cls.ddl.add_fk(model, field_describe, field_describe_target) @classmethod def _merge_operators(cls): diff --git a/poetry.lock b/poetry.lock index ad415ea..a0bc04d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -138,20 +138,6 @@ python-versions = "*" [package.dependencies] pyparsing = "*" -[[package]] -name = "dictdiffer" -version = "0.8.1" -description = "Dictdiffer is a library that helps you to diff and patch dictionaries." -category = "main" -optional = false -python-versions = "*" - -[package.extras] -all = ["Sphinx (>=1.4.4)", "sphinx-rtd-theme (>=0.1.9)", "check-manifest (>=0.25)", "coverage (>=4.0)", "isort (>=4.2.2)", "mock (>=1.3.0)", "pydocstyle (>=1.0.0)", "pytest-cov (>=1.8.0)", "pytest-pep8 (>=1.0.6)", "pytest (>=2.8.0)", "tox (>=3.7.0)", "numpy (>=1.11.0)"] -docs = ["Sphinx (>=1.4.4)", "sphinx-rtd-theme (>=0.1.9)"] -numpy = ["numpy (>=1.11.0)"] -tests = ["check-manifest (>=0.25)", "coverage (>=4.0)", "isort (>=4.2.2)", "mock (>=1.3.0)", "pydocstyle (>=1.0.0)", "pytest-cov (>=1.8.0)", "pytest-pep8 (>=1.0.6)", "pytest (>=2.8.0)", "tox (>=3.7.0)"] - [[package]] name = "execnet" version = "1.8.0" @@ -561,7 +547,7 @@ dbdrivers = ["aiomysql", "asyncpg"] [metadata] lock-version = "1.1" python-versions = "^3.7" -content-hash = "f4ef33a953946570d6d35a479dad75768cd3c6a72e5953c68f2de1566c40873b" +content-hash = "9adf7beba99d615c71a9148391386c9016cbafc7c11c5fc3ad81c8ec61026236" [metadata.files] aiomysql = [ @@ -633,10 +619,6 @@ ddlparse = [ {file = "ddlparse-1.9.0-py3-none-any.whl", hash = "sha256:a7962615a9325be7d0f182cbe34011e6283996473fb98c784c6f675b9783bc18"}, {file = "ddlparse-1.9.0.tar.gz", hash = "sha256:cdffcf2f692f304a23c8e903b00afd7e83a920b79a2ff4e2f25c875b369d4f58"}, ] -dictdiffer = [ - {file = "dictdiffer-0.8.1-py2.py3-none-any.whl", hash = "sha256:d79d9a39e459fe33497c858470ca0d2e93cb96621751de06d631856adfd9c390"}, - {file = "dictdiffer-0.8.1.tar.gz", hash = "sha256:1adec0d67cdf6166bda96ae2934ddb5e54433998ceab63c984574d187cc563d2"}, -] execnet = [ {file = "execnet-1.8.0-py2.py3-none-any.whl", hash = "sha256:7a13113028b1e1cc4c6492b28098b3c6576c9dccc7973bfe47b342afadafb2ac"}, {file = "execnet-1.8.0.tar.gz", hash = "sha256:b73c5565e517f24b62dea8a5ceac178c661c4309d3aa0c3e420856c072c411b4"}, diff --git a/pyproject.toml b/pyproject.toml index c146d41..01cf9a4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,13 +16,12 @@ include = ["CHANGELOG.md", "LICENSE", "README.md"] [tool.poetry.dependencies] python = "^3.7" -tortoise-orm = "*" +tortoise-orm = "^0.16.21" click = "*" pydantic = "*" -aiomysql = {version = "*", optional = true} -asyncpg = {version = "*", optional = true} +aiomysql = { version = "*", optional = true } +asyncpg = { version = "*", optional = true } ddlparse = "*" -dictdiffer = "*" [tool.poetry.dev-dependencies] flake8 = "*" diff --git a/tests/models.py b/tests/models.py index db14ba1..a3ee9ac 100644 --- a/tests/models.py +++ b/tests/models.py @@ -28,12 +28,12 @@ class User(Model): is_active = fields.BooleanField(default=True, description="Is Active") is_superuser = fields.BooleanField(default=False, description="Is SuperUser") avatar = fields.CharField(max_length=200, default="") + intro = fields.TextField(default="") class Email(Model): email = fields.CharField(max_length=200) is_primary = fields.BooleanField(default=False) - user = fields.ForeignKeyField("models.User", db_constraint=False) class Category(Model): diff --git a/tests/test_ddl.py b/tests/test_ddl.py index fcba3ae..410d5a4 100644 --- a/tests/test_ddl.py +++ b/tests/test_ddl.py @@ -58,7 +58,7 @@ def test_drop_table(): def test_add_column(): - ret = Migrate.ddl.add_column(Category, Category._meta.fields_map.get("name")) + ret = Migrate.ddl.add_column(Category, Category._meta.fields_map.get("name").describe(False)) if isinstance(Migrate.ddl, MysqlDDL): assert ret == "ALTER TABLE `category` ADD `name` VARCHAR(200) NOT NULL" else: @@ -180,7 +180,7 @@ def test_drop_index(): def test_add_fk(): - ret = Migrate.ddl.add_fk(Category, Category._meta.fields_map.get("user")) + ret = Migrate.ddl.add_fk(Category, Category._meta.fields_map.get("user").describe(False)) if isinstance(Migrate.ddl, MysqlDDL): assert ( ret